Detailed changes
@@ -5144,6 +5144,7 @@ dependencies = [
"ctor",
"db",
"edit_prediction_context",
+ "edit_prediction_metrics",
"edit_prediction_types",
"feature_flags",
"fs",
@@ -5206,6 +5207,7 @@ dependencies = [
"debug_adapter_extension",
"dirs",
"edit_prediction",
+ "edit_prediction_metrics",
"extension",
"flate2",
"fs",
@@ -5280,6 +5282,19 @@ dependencies = [
"zeta_prompt",
]
+[[package]]
+name = "edit_prediction_metrics"
+version = "0.1.0"
+dependencies = [
+ "indoc",
+ "language",
+ "pretty_assertions",
+ "serde",
+ "similar",
+ "tree-sitter",
+ "zeta_prompt",
+]
+
[[package]]
name = "edit_prediction_types"
version = "0.1.0"
@@ -57,6 +57,7 @@ members = [
"crates/edit_prediction",
"crates/edit_prediction_cli",
"crates/edit_prediction_context",
+ "crates/edit_prediction_metrics",
"crates/edit_prediction_types",
"crates/edit_prediction_ui",
"crates/editor",
@@ -480,6 +481,7 @@ zed_actions = { path = "crates/zed_actions" }
zed_credentials_provider = { path = "crates/zed_credentials_provider" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
+edit_prediction_metrics = { path = "crates/edit_prediction_metrics" }
zeta_prompt = { path = "crates/zeta_prompt" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
@@ -31,6 +31,7 @@ credentials_provider.workspace = true
db.workspace = true
edit_prediction_types.workspace = true
edit_prediction_context.workspace = true
+edit_prediction_metrics.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
@@ -1,10 +1,8 @@
-mod kept_rate;
-mod tokenize;
-mod tree_sitter;
+pub use edit_prediction_metrics::KeptRateResult;
-pub use kept_rate::KeptRateResult;
-#[cfg(test)]
-pub use kept_rate::TokenAnnotation;
-pub use kept_rate::compute_kept_rate;
-pub(crate) use tokenize::tokenize;
-pub use tree_sitter::count_tree_sitter_errors;
+pub use edit_prediction_metrics::compute_kept_rate;
+use language::SyntaxLayer;
+
+pub fn count_tree_sitter_errors<'a>(layers: impl Iterator<Item = SyntaxLayer<'a>>) -> usize {
+ edit_prediction_metrics::count_tree_sitter_errors(layers.map(|layer| layer.node()))
+}
@@ -1,88 +0,0 @@
-use language::SyntaxLayer;
-
-pub fn count_tree_sitter_errors<'a>(layers: impl Iterator<Item = SyntaxLayer<'a>>) -> usize {
- let mut total_count: usize = 0;
- for layer in layers {
- let node = layer.node();
- let mut cursor = node.walk();
- 'layer: loop {
- let current = cursor.node();
- if current.is_error() || current.is_missing() {
- total_count += 1;
- }
- if current.has_error() && cursor.goto_first_child() {
- continue;
- }
- if cursor.goto_next_sibling() {
- continue;
- }
- loop {
- if !cursor.goto_parent() {
- break 'layer;
- }
- if cursor.goto_next_sibling() {
- continue;
- }
- }
- }
- }
- total_count
-}
-
-#[cfg(test)]
-mod tests {
- use std::ops::Range;
-
- use super::count_tree_sitter_errors;
- use gpui::{AppContext as _, TestAppContext};
- use language::{Buffer, BufferSnapshot, rust_lang};
-
- fn error_count_in_range(edited_buffer_snapshot: &BufferSnapshot, range: Range<usize>) -> usize {
- let layers = edited_buffer_snapshot.syntax_layers_for_range(range, true);
- count_tree_sitter_errors(layers)
- }
-
- fn rust_snapshot(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
- while buffer.read_with(cx, |buffer, _| buffer.is_parsing()) {
- cx.run_until_parked();
- }
- buffer.read_with(cx, |buffer, _| buffer.snapshot())
- }
-
- #[gpui::test]
- async fn counts_no_errors_for_valid_rust(cx: &mut TestAppContext) {
- let text = "fn helper(value: usize) -> usize {\n value + 1\n}\n";
- let snapshot = rust_snapshot(text, cx);
-
- assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 0);
- }
-
- #[gpui::test]
- async fn counts_errors_for_invalid_rust(cx: &mut TestAppContext) {
- let text = "fn helper(value: usize) -> usize {\n let total = ;\n total\n}\n";
- let snapshot = rust_snapshot(text, cx);
-
- assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 1);
- }
-
- #[gpui::test]
- async fn counts_no_errors_for_subrange_of_valid_rust(cx: &mut TestAppContext) {
- let text = "fn first() -> usize {\n let value = 1;\n value + 1\n}\n";
- let snapshot = rust_snapshot(text, cx);
- let body_start = text.find("let value").unwrap();
- let body_end = body_start + "let value = 1;".len();
-
- assert_eq!(error_count_in_range(&snapshot, body_start..body_end), 0);
- }
-
- #[gpui::test]
- async fn counts_errors_for_subrange_of_invalid_rust(cx: &mut TestAppContext) {
- let text = "fn second() -> usize {\n let broken = ;\n broken\n}\n";
- let snapshot = rust_snapshot(text, cx);
- let error_start = text.find("let broken = ;").unwrap();
- let error_end = error_start + "let broken = ;".len();
-
- assert_eq!(error_count_in_range(&snapshot, error_start..error_end), 1);
- }
-}
@@ -61,6 +61,7 @@ terminal_view.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["cli-support"] }
+edit_prediction_metrics.workspace = true
telemetry_events.workspace = true
wasmtime.workspace = true
zeta_prompt.workspace = true
@@ -1,1301 +1,24 @@
-use collections::HashMap;
+#![allow(unused_imports)]
+
+use crate::example::ActualCursor;
+
+pub use edit_prediction_metrics::ClassificationMetrics;
+pub use edit_prediction_metrics::Counts;
+pub use edit_prediction_metrics::DeltaChrFMetrics;
+pub use edit_prediction_metrics::KeptRateResult;
+pub use edit_prediction_metrics::TokenChangeCounts;
+pub use edit_prediction_metrics::braces_disbalance;
+pub use edit_prediction_metrics::compute_kept_rate;
+pub use edit_prediction_metrics::count_patch_token_changes;
+pub use edit_prediction_metrics::delta_chr_f;
+pub use edit_prediction_metrics::delta_chr_f_beta;
+pub use edit_prediction_metrics::exact_lines_match;
+pub use edit_prediction_metrics::extract_changed_lines_from_diff;
+pub use edit_prediction_metrics::is_editable_region_correct;
-use crate::{
- example::ActualCursor,
- reorder_patch::{Patch, PatchLine},
- word_diff::{DiffOp, diff_tokens, tokenize},
-};
-
-pub type Counts = HashMap<String, usize>;
-type CountsDelta = HashMap<String, isize>;
-
-/// Context characters needed on each side of a change to capture all affected n-grams
-const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1;
-
-#[derive(Default, Debug, Clone)]
-pub struct ClassificationMetrics {
- pub true_positives: usize,
- pub false_positives: usize,
- pub false_negatives: usize,
-}
-
-impl ClassificationMetrics {
- pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
- let mut true_positives = 0;
- let mut false_positives = 0;
- let mut false_negatives = 0;
-
- for (ngram, &expected_count) in expected {
- let actual_count = *actual.get(ngram).unwrap_or(&0);
- if actual_count > expected_count {
- false_positives += actual_count - expected_count;
- } else {
- false_negatives += expected_count - actual_count;
- }
- true_positives += expected_count.min(actual_count);
- }
-
- for (ngram, &actual_count) in actual {
- if !expected.contains_key(ngram) {
- false_positives += actual_count;
- }
- }
-
- ClassificationMetrics {
- true_positives,
- false_positives,
- false_negatives,
- }
- }
-
- pub fn accumulate(&mut self, other: &ClassificationMetrics) {
- self.true_positives += other.true_positives;
- self.false_positives += other.false_positives;
- self.false_negatives += other.false_negatives;
- }
-
- pub fn precision(&self) -> f64 {
- if self.true_positives + self.false_positives == 0 {
- 0.0
- } else {
- self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
- }
- }
-
- pub fn recall(&self) -> f64 {
- if self.true_positives + self.false_negatives == 0 {
- 0.0
- } else {
- self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
- }
- }
-
- pub fn f1(&self) -> f64 {
- let precision = self.precision();
- let recall = self.recall();
- if precision + recall == 0.0 {
- 0.0
- } else {
- 2.0 * precision * recall / (precision + recall)
- }
- }
-}
-
-enum ChrfWhitespace {
- /// Preserve whitespace as-is
- #[allow(unused)]
- Unchanged,
-
- /// Ignore all whitespace differences
- #[allow(unused)]
- Ignore,
-
- /// Collapse whitespace into single spaces
- Collapse,
-}
-
-const CHR_F_CHAR_ORDER: usize = 6;
-const CHR_F_BETA: f64 = 0.5;
-const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse;
-
-pub fn delta_chr_f_beta() -> f64 {
- CHR_F_BETA
-}
-
-#[derive(Default, Debug, Clone)]
-pub struct DeltaChrFMetrics {
- pub score: f64,
- pub beta: f64,
- pub counts: ClassificationMetrics,
- pub precision: f64,
- pub recall: f64,
-}
-
-/// Computes delta-chrF metrics that compare two sets of edits.
-///
-/// This metric works by:
-/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
-/// 2. Comparing these deltas to measure how well actual edits match expected edits
-///
-/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
-/// the expected edits.
-pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
- if original == expected && expected == actual {
- return DeltaChrFMetrics {
- score: 100.0,
- beta: CHR_F_BETA,
- precision: 1.0,
- recall: 1.0,
- ..DeltaChrFMetrics::default()
- };
- }
-
- let orig_chars: Vec<char> = filter_whitespace_chars(original);
- let exp_chars: Vec<char> = filter_whitespace_chars(expected);
- let act_chars: Vec<char> = filter_whitespace_chars(actual);
-
- // Find the changed regions between original→expected and original→actual
- // We only need to compute n-grams on these regions (plus context for boundary n-grams)
- let (orig_for_exp, exp_region) = extract_changed_regions(&orig_chars, &exp_chars);
- let (orig_for_act, act_region) = extract_changed_regions(&orig_chars, &act_chars);
-
- let mut total_precision = 0.0;
- let mut total_recall = 0.0;
- let mut total_counts = ClassificationMetrics::default();
-
- for order in 1..=CHR_F_CHAR_ORDER {
- let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order);
- let exp_ngrams = count_ngrams_from_chars(&exp_region, order);
- let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp);
-
- let orig_ngrams_for_act = count_ngrams_from_chars(&orig_for_act, order);
- let act_ngrams = count_ngrams_from_chars(&act_region, order);
- let actual_delta = compute_ngram_delta(&act_ngrams, &orig_ngrams_for_act);
-
- if expected_delta.is_empty() && actual_delta.is_empty() {
- total_precision += 1.0;
- total_recall += 1.0;
- continue;
- }
-
- let expected_counts = ngram_delta_to_counts(&expected_delta);
- let actual_counts = ngram_delta_to_counts(&actual_delta);
-
- let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
- total_precision += counts.precision();
- total_recall += counts.recall();
- total_counts.accumulate(&counts);
- }
-
- let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
- let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
- let score = if average_precision + average_recall == 0.0 {
- 0.0
- } else {
- (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
- / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
- * 100.0
- };
-
- DeltaChrFMetrics {
- score,
- beta: CHR_F_BETA,
- counts: total_counts,
- precision: average_precision,
- recall: average_recall,
- }
-}
-
-/// Reference implementation of delta-chrF metrics (original, non-optimized version).
-/// Used for testing that the optimized version produces identical results.
-#[cfg(test)]
-fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
- if original == expected && expected == actual {
- return DeltaChrFMetrics {
- score: 100.0,
- beta: CHR_F_BETA,
- precision: 1.0,
- recall: 1.0,
- ..DeltaChrFMetrics::default()
- };
- }
-
- let original_ngrams = chr_f_ngram_counts(original);
- let expected_ngrams = chr_f_ngram_counts(expected);
- let actual_ngrams = chr_f_ngram_counts(actual);
-
- let mut total_precision = 0.0;
- let mut total_recall = 0.0;
- let mut total_counts = ClassificationMetrics::default();
-
- for order in 0..CHR_F_CHAR_ORDER {
- let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
- let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
-
- if expected_delta.is_empty() && actual_delta.is_empty() {
- total_precision += 1.0;
- total_recall += 1.0;
- continue;
- }
-
- let expected_counts = ngram_delta_to_counts(&expected_delta);
- let actual_counts = ngram_delta_to_counts(&actual_delta);
-
- let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
- total_precision += counts.precision();
- total_recall += counts.recall();
- total_counts.accumulate(&counts);
- }
-
- let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
- let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
- let score = if average_precision + average_recall == 0.0 {
- 0.0
- } else {
- (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
- / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
- * 100.0
- };
-
- DeltaChrFMetrics {
- score,
- beta: CHR_F_BETA,
- counts: total_counts,
- precision: average_precision,
- recall: average_recall,
- }
-}
-
-/// Filter whitespace from a string and return as Vec<char>
-fn filter_whitespace_chars(text: &str) -> Vec<char> {
- match CHR_F_WHITESPACE {
- ChrfWhitespace::Unchanged => text.chars().collect(),
- ChrfWhitespace::Ignore => text.chars().filter(|c| !c.is_whitespace()).collect(),
- ChrfWhitespace::Collapse => collapse_whitespace(text.chars()),
- }
-}
-
-/// Collapse whitespace into single spaces.
-/// Newlines and spaces are collapsed separately.
-fn collapse_whitespace(chars: impl Iterator<Item = char>) -> Vec<char> {
- let mut result = Vec::new();
- let mut last_whitespace = None;
- for c in chars {
- if c.is_whitespace() && c != '\n' {
- if last_whitespace != Some(' ') {
- result.push(' ');
- last_whitespace = Some(' ');
- }
- } else if c == '\n' {
- if last_whitespace != Some('\n') {
- result.push(c);
- last_whitespace = Some('\n');
- }
- } else {
- result.push(c);
- last_whitespace = None;
- }
- }
- result
-}
-
-/// Extract only the changed regions between two texts, with context for n-gram boundaries.
-///
-/// Returns (original_affected_region, modified_affected_region) as Vec<char>.
-///
-/// The key insight: when computing n-gram delta between two nearly-identical texts,
-/// n-grams from unchanged regions cancel out. We only need to process:
-/// 1. The changed content itself
-/// 2. CONTEXT_CHARS (n-1) characters before and after, to capture boundary-crossing n-grams
-fn extract_changed_regions(original: &[char], modified: &[char]) -> (Vec<char>, Vec<char>) {
- // Find longest common prefix
- let prefix_len = original
- .iter()
- .zip(modified.iter())
- .take_while(|(a, b)| a == b)
- .count();
-
- // Find longest common suffix (that doesn't overlap with prefix)
- let orig_remaining = original.len().saturating_sub(prefix_len);
- let mod_remaining = modified.len().saturating_sub(prefix_len);
- let max_suffix = orig_remaining.min(mod_remaining);
-
- let suffix_len = original
- .iter()
- .rev()
- .zip(modified.iter().rev())
- .take(max_suffix)
- .take_while(|(a, b)| a == b)
- .count();
-
- // Calculate the changed region boundaries
- let orig_change_start = prefix_len;
- let orig_change_end = original.len().saturating_sub(suffix_len);
- let mod_change_start = prefix_len;
- let mod_change_end = modified.len().saturating_sub(suffix_len);
-
- // If there's no actual change, return empty regions
- if orig_change_start >= orig_change_end && mod_change_start >= mod_change_end {
- return (Vec::new(), Vec::new());
- }
-
- // Expand to include context for n-gram boundaries
- let orig_context_start = orig_change_start.saturating_sub(CONTEXT_CHARS);
- let orig_context_end = (orig_change_end + CONTEXT_CHARS).min(original.len());
- let mod_context_start = mod_change_start.saturating_sub(CONTEXT_CHARS);
- let mod_context_end = (mod_change_end + CONTEXT_CHARS).min(modified.len());
-
- let orig_region: Vec<char> = original[orig_context_start..orig_context_end].to_vec();
- let mod_region: Vec<char> = modified[mod_context_start..mod_context_end].to_vec();
-
- (orig_region, mod_region)
-}
-
-/// Count n-grams directly from a char slice (avoids String allocation for the full text)
-fn count_ngrams_from_chars(chars: &[char], n: usize) -> Counts {
- let mut counts = Counts::default();
-
- if chars.len() < n {
- return counts;
- }
-
- for window in chars.windows(n) {
- let ngram: String = window.iter().collect();
- *counts.entry(ngram).or_insert(0) += 1;
- }
-
- counts
-}
-
-#[allow(dead_code)]
-fn chr_f_ngram_counts(text: &str) -> Vec<Counts> {
- let text = match CHR_F_WHITESPACE {
- ChrfWhitespace::Unchanged => text.to_string(),
- ChrfWhitespace::Ignore => text
- .chars()
- .filter(|c| !c.is_whitespace())
- .collect::<String>(),
- ChrfWhitespace::Collapse => collapse_whitespace(text.chars())
- .into_iter()
- .collect::<String>(),
- };
-
- (1..=CHR_F_CHAR_ORDER)
- .map(|order| count_ngrams(&text, order))
- .collect()
-}
-
-fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta {
- let mut delta = CountsDelta::default();
-
- for (ngram, &before_count) in before {
- let after_count = *after.get(ngram).unwrap_or(&0);
- delta.insert(ngram.clone(), after_count as isize - before_count as isize);
- }
-
- for (ngram, &after_count) in after {
- if !before.contains_key(ngram) {
- delta.insert(ngram.clone(), after_count as isize);
- }
- }
-
- delta
-}
-
-/// Convert negative counts to special deletion tokens.
-/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1},
-/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo"
-/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive.
-fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
- let mut counts = Counts::default();
-
- for (ngram, &delta) in delta {
- if delta > 0 {
- counts.insert(ngram.clone(), delta as usize);
- } else if delta < 0 {
- counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
- }
- }
-
- counts
-}
-
-#[allow(dead_code)]
-fn count_ngrams(text: &str, n: usize) -> Counts {
- let chars: Vec<char> = text.chars().collect();
- let mut counts = Counts::default();
-
- for window in chars.windows(n) {
- let ngram: String = window.iter().collect();
- *counts.entry(ngram).or_insert(0) += 1;
- }
-
- counts
-}
-
-pub fn braces_disbalance(text: &str) -> usize {
- let mut disbalance = 0isize;
-
- let a = text.chars().filter(|&c| c == '{').count() as isize;
- let b = text.chars().filter(|&c| c == '}').count() as isize;
- disbalance += (a - b).abs();
-
- let a = text.chars().filter(|&c| c == '(').count() as isize;
- let b = text.chars().filter(|&c| c == ')').count() as isize;
- disbalance += (a - b).abs();
-
- let a = text.chars().filter(|&c| c == '[').count() as isize;
- let b = text.chars().filter(|&c| c == ']').count() as isize;
- disbalance += (a - b).abs();
-
- disbalance as usize
-}
-
-/// Extracts changed lines from a unified diff string.
-/// Returns a bag (multiset) of lines that were added (+) or removed (-).
-/// The +/- prefix is included in the line to distinguish additions from deletions.
-pub fn extract_changed_lines_from_diff(diff: &str) -> Counts {
- let mut counts = Counts::default();
-
- for line in diff.lines() {
- // Skip file headers (--- and +++)
- if line.starts_with("---") || line.starts_with("+++") {
- continue;
- }
- // Skip hunk headers (@@)
- if line.starts_with("@@") {
- continue;
- }
- // Skip diff header lines (diff --git, index, etc.)
- if line.starts_with("diff ") || line.starts_with("index ") {
- continue;
- }
- // Include added and removed lines (with their prefix)
- if line.starts_with('+') || line.starts_with('-') {
- *counts.entry(line.to_string()).or_insert(0) += 1;
- }
- }
-
- counts
-}
-
-/// Computes exact lines match metrics between expected and actual patches.
-/// Treats changed lines as a bag (multiset) - order is discarded but count matters.
-/// Returns ClassificationMetrics with TP/FP/FN counts.
-pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics {
- let expected_lines = extract_changed_lines_from_diff(expected_patch);
- let actual_lines = extract_changed_lines_from_diff(actual_patch);
- ClassificationMetrics::from_counts(&expected_lines, &actual_lines)
-}
-
-/// Returns whether the patch contains any isolated whitespace-only changes.
-///
-/// A whitespace-only change is an added or deleted line whose content is empty or
-/// contains only whitespace. It is "isolated" when it is not adjacent to any
-/// substantive (non-whitespace) change within the same contiguous change group.
pub fn has_isolated_whitespace_changes(patch_str: &str, cursor: Option<&ActualCursor>) -> bool {
- let patch = Patch::parse_unified_diff(patch_str);
-
- let cursor_new_file_line = cursor.as_ref().map(|c| (c.row + 1) as usize);
-
- for hunk in &patch.hunks {
- let lines = &hunk.lines;
- let mut new_text_line = hunk.new_start as usize;
-
- for (i, line) in lines.iter().enumerate() {
- let content = match line {
- PatchLine::Addition(s) => {
- let addition_line = new_text_line;
- new_text_line += 1;
- if s.trim().is_empty() && cursor_new_file_line == Some(addition_line) {
- continue;
- }
- s.as_str()
- }
- PatchLine::Deletion(s) => s.as_str(),
- PatchLine::Context(_) => {
- new_text_line += 1;
- continue;
- }
- _ => continue,
- };
-
- if !content.trim().is_empty() {
- continue;
- }
-
- if is_whitespace_change_isolated(lines, i) {
- return true;
- }
- }
- }
-
- false
+ edit_prediction_metrics::has_isolated_whitespace_changes(
+ patch_str,
+ cursor.map(|cursor| cursor.row),
+ )
}
-
-fn is_whitespace_change_isolated(lines: &[PatchLine], index: usize) -> bool {
- // Look backward for a non-whitespace change before hitting a context line
- for line in lines[..index].iter().rev() {
- match line {
- PatchLine::Addition(s) | PatchLine::Deletion(s) => {
- if !s.trim().is_empty() {
- return false;
- }
- }
- _ => break,
- }
- }
-
- // Look forward for a non-whitespace change before hitting a context line
- for line in &lines[index + 1..] {
- match line {
- PatchLine::Addition(s) | PatchLine::Deletion(s) => {
- if !s.trim().is_empty() {
- return false;
- }
- }
- _ => break,
- }
- }
-
- true
-}
-
-/// A simple proxy for whether the prediction respects editable region.
-pub fn is_editable_region_correct(actual_patch: &str) -> bool {
- // A typical sign of a wrong editable region: a bunch of lines deletion
- // at the beginning or end of the patch.
- let patch = Patch::parse_unified_diff(actual_patch);
- if patch.hunks.is_empty() {
- return true;
- }
-
- let hunk = &patch.hunks[0];
- let mut deletions_at_start = 0;
-
- for line in hunk.lines.iter() {
- match line {
- PatchLine::Deletion(_) => deletions_at_start += 1,
- _ => break,
- }
- }
-
- if deletions_at_start >= 3 {
- return false;
- }
-
- true
-}
-
-#[derive(Debug, Default, Clone)]
-pub struct TokenChangeCounts {
- pub inserted_tokens: usize,
- pub deleted_tokens: usize,
-}
-
-/// Counts the number of inserted and deleted tokens in a unified diff patch.
-///
-/// Tokens are words and whitespace sequences (as defined by `word_diff::tokenize`).
-/// Within each hunk, the old (`-`) and new (`+`) lines are compared at the token level
-/// using an LCS-based diff, so modified lines only count the actually changed tokens
-/// rather than the entire line.
-pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts {
- let mut counts = TokenChangeCounts::default();
- let mut old_lines: Vec<&str> = Vec::new();
- let mut new_lines: Vec<&str> = Vec::new();
-
- let flush =
- |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, counts: &mut TokenChangeCounts| {
- if old_lines.is_empty() && new_lines.is_empty() {
- return;
- }
-
- let old_text: String = old_lines
- .iter()
- .map(|line| if line.len() > 1 { &line[1..] } else { "" })
- .collect::<Vec<_>>()
- .join("\n");
-
- let new_text: String = new_lines
- .iter()
- .map(|line| if line.len() > 1 { &line[1..] } else { "" })
- .collect::<Vec<_>>()
- .join("\n");
-
- let old_tokens = tokenize(&old_text);
- let new_tokens = tokenize(&new_text);
- let ops = diff_tokens(&old_tokens, &new_tokens);
-
- for op in ops {
- match op {
- DiffOp::Equal(..) => {}
- DiffOp::Delete(start, end) => {
- counts.deleted_tokens += end - start;
- }
- DiffOp::Insert(start, end) => {
- counts.inserted_tokens += end - start;
- }
- DiffOp::Replace {
- old_start,
- old_end,
- new_start,
- new_end,
- } => {
- counts.deleted_tokens += old_end - old_start;
- counts.inserted_tokens += new_end - new_start;
- }
- }
- }
-
- old_lines.clear();
- new_lines.clear();
- };
-
- for line in patch.lines() {
- if line.starts_with("---")
- || line.starts_with("+++")
- || line.starts_with("@@")
- || line.starts_with("diff ")
- || line.starts_with("index ")
- {
- flush(&mut old_lines, &mut new_lines, &mut counts);
- } else if line.starts_with('-') {
- old_lines.push(line);
- } else if line.starts_with('+') {
- new_lines.push(line);
- } else {
- flush(&mut old_lines, &mut new_lines, &mut counts);
- }
- }
-
- flush(&mut old_lines, &mut new_lines, &mut counts);
- counts
-}
-
-#[cfg(test)]
-mod test_optimization {
- use super::*;
-
- #[test]
- fn test_extract_changed_regions_simple() {
- let original: Vec<char> = "hello world".chars().collect();
- let modified: Vec<char> = "hello there".chars().collect();
-
- let (orig_region, mod_region) = extract_changed_regions(&original, &modified);
-
- // "world" vs "there" - with 5 chars context, we get "ello world" vs "ello there"
- // (or less if not enough chars available)
- assert!(orig_region.len() < original.len());
- assert!(mod_region.len() < modified.len());
- }
-
- #[test]
- fn test_extract_changed_regions_insertion() {
- let original: Vec<char> = "abcdef".chars().collect();
- let modified: Vec<char> = "abcXYZdef".chars().collect();
-
- let (orig_region, mod_region) = extract_changed_regions(&original, &modified);
-
- // The insertion is between c and d, so we need context around that point
- assert!(orig_region.len() <= original.len());
- assert!(mod_region.iter().collect::<String>().contains("XYZ"));
- }
-
- #[test]
- fn test_extract_changed_regions_identical() {
- let text: Vec<char> = "identical text".chars().collect();
-
- let (orig_region, mod_region) = extract_changed_regions(&text, &text);
-
- // When texts are identical, regions should be empty
- assert!(orig_region.is_empty());
- assert!(mod_region.is_empty());
- }
-
- #[test]
- fn test_optimized_matches_original_score() {
- // Test that our optimized version produces the same results
- let test_cases = vec![
- ("hello world", "hello there", "hello world"),
- (
- "fn main() {}",
- "fn main() { println!(); }",
- "fn main() { print!(); }",
- ),
- ("abcdefghij", "abcXXXghij", "abcYYghij"),
- ("unchanged", "unchanged", "unchanged"),
- (
- "prefix middle suffix",
- "prefix CHANGED suffix",
- "prefix middle suffix",
- ),
- ];
-
- for (original, expected, actual) in test_cases {
- let score = delta_chr_f(original, expected, actual).score;
- // Just verify it produces a reasonable score (0-100)
- assert!(
- score >= 0.0 && score <= 100.0,
- "Score {} out of range for ({}, {}, {})",
- score,
- original,
- expected,
- actual
- );
- }
- }
-
- #[test]
- fn test_optimized_equals_reference() {
- // Comprehensive test that optimized version matches reference implementation exactly
- let test_cases = vec![
- // Basic cases
- ("hello world", "hello there", "hello world"),
- ("hello world", "hello there", "hello there"),
- ("unchanged", "unchanged", "unchanged"),
- // Code-like cases
- (
- "fn main() { println!(\"Hello\"); }",
- "fn main() { println!(\"Hello, World!\"); }",
- "fn main() { println!(\"Hello, World!\"); }",
- ),
- (
- "fn main() { println!(\"Hello\"); }",
- "fn main() { println!(\"Hello, World!\"); }",
- "fn main() { println!(\"Goodbye\"); }",
- ),
- // Insertion
- ("abcdef", "abcXYZdef", "abcdef"),
- ("abcdef", "abcXYZdef", "abcXYZdef"),
- ("abcdef", "abcXYZdef", "abcABCdef"),
- // Deletion
- ("abcXYZdef", "abcdef", "abcXYZdef"),
- ("abcXYZdef", "abcdef", "abcdef"),
- // Multiple changes (simulated by different expected/actual)
- ("one two three four", "one THREE four", "one two FOUR"),
- // Edge cases
- ("a", "b", "c"),
- ("", "abc", ""),
- ("abc", "", "abc"),
- // Longer text with small change
- (
- "This is a longer piece of text that contains many words and characters to process",
- "This is a longer piece of TEXT that contains many words and characters to process",
- "This is a longer piece of text that contains many words and characters to process",
- ),
- // Change at the beginning
- (
- "ORIGINAL start of text",
- "NEW start of text",
- "DIFFERENT start of text",
- ),
- // Change at the end
- (
- "text ending ORIGINAL",
- "text ending NEW",
- "text ending DIFFERENT",
- ),
- // Whitespace (should be ignored)
- ("hello world", "hello there", "hello world"),
- ("a b c d", "a X c d", "a Y c d"),
- ];
-
- for (original, expected, actual) in test_cases {
- let optimized_metrics = delta_chr_f(original, expected, actual);
- let reference_metrics = delta_chr_f_reference(original, expected, actual);
-
- assert!(
- (optimized_metrics.score - reference_metrics.score).abs() < 1e-10,
- "Score mismatch for ({:?}, {:?}, {:?}):\n optimized: {}\n reference: {}",
- original,
- expected,
- actual,
- optimized_metrics.score,
- reference_metrics.score
- );
- assert_eq!(
- optimized_metrics.counts.true_positives,
- reference_metrics.counts.true_positives
- );
- assert_eq!(
- optimized_metrics.counts.false_positives,
- reference_metrics.counts.false_positives
- );
- assert_eq!(
- optimized_metrics.counts.false_negatives,
- reference_metrics.counts.false_negatives
- );
- assert!((optimized_metrics.precision - reference_metrics.precision).abs() < 1e-10);
- assert!((optimized_metrics.recall - reference_metrics.recall).abs() < 1e-10);
- }
- }
-
- #[test]
- fn test_delta_chr_f_metrics_include_counts_and_rates() {
- let original = "one two three";
- let expected = "one three";
- let actual = "one two four";
-
- let metrics = delta_chr_f(original, expected, actual);
-
- assert!(metrics.score > 20.0 && metrics.score < 40.0);
- assert!(metrics.counts.true_positives > 0);
- assert!(metrics.counts.false_positives > 0);
- assert!(metrics.counts.false_negatives > 0);
- assert!(metrics.precision > 0.0 && metrics.precision < 1.0);
- assert!(metrics.recall > 0.0 && metrics.recall < 1.0);
- assert_eq!(metrics.beta, CHR_F_BETA);
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
- use crate::example::ActualCursor;
- use indoc::indoc;
-
- fn cursor_on_line(one_based_line: u32) -> ActualCursor {
- ActualCursor {
- path: String::new(),
- row: one_based_line - 1,
- column: 0,
- offset: 0,
- editable_region_offset: None,
- }
- }
-
- #[test]
- fn test_delta_chr_f_perfect_match() {
- let original = "fn main() { println!(\"Hello\");}";
- let expected = "fn main() { println!(\"Hello, World!\");}";
-
- let score = delta_chr_f(original, expected, expected).score;
- assert!((score - 100.0).abs() < 1e-2);
- }
-
- #[test]
- fn test_delta_chr_f_wrong_edit() {
- // When the edit is wrong
- let original = "one two three";
- let expected = "one three"; // deleted "two "
- let actual = "one two four"; // deleted "three", added "four"
-
- // Then the score should be low
- let score = delta_chr_f(original, expected, actual).score;
- assert!(score > 20.0 && score < 40.0);
- }
-
- #[test]
- fn test_delta_chr_f_partial_match() {
- let original = "let x = 42;";
- let expected = "let x = 100;";
- let actual = "let x = 99;";
-
- // We got the edit location right, but the replacement text is wrong.
- // Deleted ngrams will match, bringing the score somewhere in the middle.
- let score = delta_chr_f(original, expected, actual).score;
- assert!(score > 40.0 && score < 60.0);
- }
-
- #[test]
- fn test_delta_chr_f_missed_edit() {
- // When predictions makes no changes
- let original = "prefix old suffix";
- let expected = "prefix new suffix";
- let actual = "prefix old suffix"; // no change
-
- // Then the score should be low (all expected changes are false negatives)
- let score = delta_chr_f(original, expected, actual).score;
- assert!(score < 20.0);
- }
-
- #[test]
- fn test_delta_chr_f_extra_edit() {
- // When adding unexpected content
- let original = "helloworld";
- let expected = "helloworld"; // no change expected
- let actual = "helloextraworld"; // added "extra"
-
- // Then the score should be low (all actual changes are false positives)
- let score = delta_chr_f(original, expected, actual).score;
- assert!(score < 20.0);
- }
-
- #[test]
- fn test_delta_chr_f_no_changes() {
- let text = "unchanged text";
- let score = delta_chr_f(text, text, text).score;
- assert!((score - 100.0).abs() < 1e-2);
- }
-
- #[test]
- fn test_braces_disbalance() {
- let text = "let x = { 1 + 2 };";
- assert_eq!(braces_disbalance(text), 0);
-
- let text = "let x = { 1 + 2";
- assert_eq!(braces_disbalance(text), 1);
-
- let text = "let x = { 1 + 2 )";
- assert_eq!(braces_disbalance(text), 2);
- }
-
- #[test]
- fn test_extract_changed_lines_from_diff() {
- let diff = r#"--- a/file.rs
-+++ b/file.rs
-@@ -1,3 +1,3 @@
- fn main() {
-- println!("hello");
-+ println!("world");
- }"#;
-
- let counts = extract_changed_lines_from_diff(diff);
- assert_eq!(counts.get("- println!(\"hello\");"), Some(&1));
- assert_eq!(counts.get("+ println!(\"world\");"), Some(&1));
- assert_eq!(counts.len(), 2);
- }
-
- #[test]
- fn test_extract_changed_lines_skips_headers() {
- let diff = r#"diff --git a/file.rs b/file.rs
-index abc123..def456 100644
---- a/file.rs
-+++ b/file.rs
-@@ -1,2 +1,2 @@
--old line
-+new line"#;
-
- let counts = extract_changed_lines_from_diff(diff);
- assert_eq!(counts.get("-old line"), Some(&1));
- assert_eq!(counts.get("+new line"), Some(&1));
- assert_eq!(counts.len(), 2);
- }
-
- #[test]
- fn test_exact_lines_match_perfect() {
- let expected = r#"--- a/file.rs
-+++ b/file.rs
-@@ -1,3 +1,3 @@
--old line 1
--old line 2
-+new line 1
-+new line 2"#;
-
- let actual = r#"--- a/file.rs
-+++ b/file.rs
-@@ -1,3 +1,3 @@
--old line 1
--old line 2
-+new line 1
-+new line 2"#;
-
- let metrics = exact_lines_match(expected, actual);
- assert_eq!(metrics.true_positives, 4);
- assert_eq!(metrics.false_positives, 0);
- assert_eq!(metrics.false_negatives, 0);
- assert!((metrics.precision() - 1.0).abs() < 1e-6);
- assert!((metrics.recall() - 1.0).abs() < 1e-6);
- assert!((metrics.f1() - 1.0).abs() < 1e-6);
- }
-
- #[test]
- fn test_exact_lines_match_partial() {
- let expected = r#"-old line 1
--old line 2
-+new line 1
-+new line 2"#;
-
- let actual = r#"-old line 1
-+new line 1
-+extra line"#;
-
- let metrics = exact_lines_match(expected, actual);
- // TP: "-old line 1" and "+new line 1" (2)
- // FP: "+extra line" (1)
- // FN: "-old line 2" and "+new line 2" (2)
- assert_eq!(metrics.true_positives, 2);
- assert_eq!(metrics.false_positives, 1);
- assert_eq!(metrics.false_negatives, 2);
- }
-
- #[test]
- fn test_exact_lines_match_no_overlap() {
- let expected = r#"-line a
-+line b"#;
-
- let actual = r#"-line x
-+line y"#;
-
- let metrics = exact_lines_match(expected, actual);
- assert_eq!(metrics.true_positives, 0);
- assert_eq!(metrics.false_positives, 2);
- assert_eq!(metrics.false_negatives, 2);
- assert!((metrics.precision()).abs() < 1e-6);
- assert!((metrics.recall()).abs() < 1e-6);
- }
-
- #[test]
- fn test_exact_lines_match_duplicate_lines() {
- let expected = r#"+line a
-+line a
-+line a"#;
-
- let actual = r#"+line a
-+line a"#;
-
- let metrics = exact_lines_match(expected, actual);
- // Expected has 3 "+line a", actual has 2
- // TP: 2, FN: 1, FP: 0
- assert_eq!(metrics.true_positives, 2);
- assert_eq!(metrics.false_positives, 0);
- assert_eq!(metrics.false_negatives, 1);
- }
-
- #[test]
- fn test_exact_lines_match_empty_patches() {
- let metrics = exact_lines_match("", "");
- assert_eq!(metrics.true_positives, 0);
- assert_eq!(metrics.false_positives, 0);
- assert_eq!(metrics.false_negatives, 0);
- }
-
- #[test]
- fn test_is_editable_region_correct() {
- let patch = indoc! {"
- @@ -1,1 +1,1 @@
- -context
- -removed
- -from the beginning of the file
- import sys
- +sys.exit(0)
-
- "};
- assert!(!is_editable_region_correct(patch));
-
- let patch = indoc! {"
- @@ -1,1 +1,1 @@
- "};
- assert!(is_editable_region_correct(patch));
- }
-
- #[test]
- fn test_isolated_whitespace_purely_whitespace_patch() {
- let patch = indoc! {"
- @@ -1,3 +1,4 @@
- fn main() {
- +
- println!(\"hello\");
- }
- "};
- assert!(has_isolated_whitespace_changes(patch, None));
- }
-
- #[test]
- fn test_isolated_whitespace_adjacent_to_real_change() {
- let patch = indoc! {"
- @@ -1,3 +1,4 @@
- fn main() {
- +
- + let x = 1;
- println!(\"hello\");
- }
- "};
- assert!(!has_isolated_whitespace_changes(patch, None));
- }
-
- #[test]
- fn test_isolated_whitespace_no_whitespace_changes() {
- let patch = indoc! {"
- @@ -1,3 +1,3 @@
- fn main() {
- - println!(\"hello\");
- + println!(\"world\");
- }
- "};
- assert!(!has_isolated_whitespace_changes(patch, None));
- }
-
- #[test]
- fn test_isolated_whitespace_deletion() {
- let patch = indoc! {"
- @@ -1,4 +1,3 @@
- fn main() {
- -
- println!(\"hello\");
- }
- "};
- assert!(has_isolated_whitespace_changes(patch, None));
- }
-
- #[test]
- fn test_isolated_whitespace_mixed_groups() {
- let patch = indoc! {"
- @@ -1,7 +1,8 @@
- fn main() {
- +
- let x = 1;
- - let y = 2;
- + let y = 3;
-
- +
- println!(\"hello\");
- }
- "};
- assert!(has_isolated_whitespace_changes(patch, None));
- }
-
- #[test]
- fn test_isolated_whitespace_empty_patch() {
- let patch = "";
- assert!(!has_isolated_whitespace_changes(patch, None));
- }
-
- #[test]
- fn test_isolated_whitespace_skipped_on_cursor_line() {
- // The addition of a blank line at new-file line 2 should be skipped
- // because the cursor is on that line.
- let patch = indoc! {"
- @@ -1,3 +1,4 @@
- fn main() {
- +
- println!(\"hello\");
- }
- "};
- // New-file line 2 is the added blank line
- let cursor = cursor_on_line(2);
- assert!(!has_isolated_whitespace_changes(patch, Some(&cursor)));
- }
-
- #[test]
- fn test_isolated_whitespace_not_skipped_when_cursor_on_different_line() {
- // The blank line is at new-file line 2, but the cursor is on line 1.
- let patch = indoc! {"
- @@ -1,3 +1,4 @@
- fn main() {
- +
- println!(\"hello\");
- }
- "};
- let cursor = cursor_on_line(1);
- assert!(has_isolated_whitespace_changes(patch, Some(&cursor)));
- }
-
- #[test]
- fn test_isolated_whitespace_deletion_not_skipped_by_cursor() {
- // Deletions don't have a new-file line, so cursor can't suppress them.
- let patch = indoc! {"
- @@ -1,4 +1,3 @@
- fn main() {
- -
- println!(\"hello\");
- }
- "};
- let cursor = cursor_on_line(2);
- assert!(has_isolated_whitespace_changes(patch, Some(&cursor)));
- }
-
- #[test]
- fn test_count_patch_token_changes_real_world_rename() {
- // Real-world patch that was reported as returning 0 tokens
- let patch = "--- a/sip_call\\README.md\n+++ b/sip_call\\README.md\n@@ -1,1 +1,1 @@\n-# \n+# SIP Call\n";
- let counts = count_patch_token_changes(patch);
- // "# " vs "# SIP Call" — the "SIP" and "Call" tokens (and a whitespace token) are inserted
- assert!(
- counts.inserted_tokens > 0,
- "expected inserted tokens > 0, got {}",
- counts.inserted_tokens
- );
- assert_eq!(counts.deleted_tokens, 0);
- }
-
- #[test]
- fn test_count_patch_token_changes_real_world_expansion() {
- // Real-world patch: single token expanded to multiple lines
- let patch = "--- a/task1/src/app/app.html\n+++ b/task1/src/app/app.html\n@@ -1,7 +1,9 @@\n <style>\n- m\n+ main {\n+ \n+ }\n </style>\n \n <main>\n \n </main>\n";
- let counts = count_patch_token_changes(patch);
- assert!(
- counts.inserted_tokens > 0,
- "expected inserted tokens > 0, got {}",
- counts.inserted_tokens
- );
- assert!(
- counts.deleted_tokens > 0,
- "expected deleted tokens > 0, got {}",
- counts.deleted_tokens
- );
- }
-
- #[test]
- fn test_count_patch_token_changes_simple_replacement() {
- let patch = indoc! {"
- @@ -1,3 +1,3 @@
- fn main() {
- - println!(\"hello\");
- + println!(\"world\");
- }
- "};
- let counts = count_patch_token_changes(patch);
- assert_eq!(counts.deleted_tokens, 1, "deleted: \"hello\"");
- assert_eq!(counts.inserted_tokens, 1, "inserted: \"world\"");
- }
-
- #[test]
- fn test_count_patch_token_changes_insertion_only() {
- let patch = indoc! {"
- @@ -1,2 +1,3 @@
- fn main() {
- + println!(\"hello\");
- }
- "};
- let counts = count_patch_token_changes(patch);
- assert_eq!(counts.deleted_tokens, 0);
- assert!(counts.inserted_tokens > 0);
- }
-
- #[test]
- fn test_count_patch_token_changes_deletion_only() {
- let patch = indoc! {"
- @@ -1,3 +1,2 @@
- fn main() {
- - println!(\"hello\");
- }
- "};
- let counts = count_patch_token_changes(patch);
- assert!(counts.deleted_tokens > 0);
- assert_eq!(counts.inserted_tokens, 0);
- }
-
- #[test]
- fn test_count_patch_token_changes_empty_patch() {
- let patch = "";
- let counts = count_patch_token_changes(patch);
- assert_eq!(counts.deleted_tokens, 0);
- assert_eq!(counts.inserted_tokens, 0);
- }
-
- #[test]
- fn test_count_patch_token_changes_multiple_hunks() {
- let patch = indoc! {"
- @@ -1,3 +1,3 @@
- fn main() {
- - let x = 1;
- + let x = 2;
- }
- @@ -10,3 +10,3 @@
- fn other() {
- - let y = 3;
- + let y = 4;
- }
- "};
- let counts = count_patch_token_changes(patch);
- assert_eq!(counts.deleted_tokens, 2, "deleted: \"1\" and \"3\"");
- assert_eq!(counts.inserted_tokens, 2, "inserted: \"2\" and \"4\"");
- }
-
- #[test]
- fn test_count_patch_token_changes_multiword_change() {
- let patch = indoc! {"
- @@ -1,1 +1,1 @@
- -hello world foo
- +hello bar baz
- "};
- let counts = count_patch_token_changes(patch);
- // "world" and "foo" deleted, "bar" and "baz" inserted
- // (whitespace tokens between them may also count)
- assert!(counts.deleted_tokens >= 2);
- assert!(counts.inserted_tokens >= 2);
- }
-
- #[test]
- fn test_whitespace_collapse() {
- let text = "abc \n\n\n 123";
- let collapsed = collapse_whitespace(text.chars());
- assert_eq!(
- collapsed,
- vec!['a', 'b', 'c', ' ', '\n', ' ', '1', '2', '3']
- );
- }
-}
-
-pub use edit_prediction::metrics::compute_kept_rate;
@@ -1,2016 +1,17 @@
-use std::ops::Range;
use std::path::Path;
-use std::sync::Arc;
-
-use language::{char_diff, text_diff};
-use zeta_prompt::udiff::apply_diff_to_string;
use zeta_prompt::ZetaPromptInput;
-fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String {
- let hunks = parse_diff_hunks(diff_str);
- let mut result = text.to_string();
-
- for hunk in hunks {
- let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk));
- if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) {
- result = updated;
- }
- }
-
- result
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-struct ParsedHunk {
- old_start: u32,
- old_count: u32,
- new_start: u32,
- new_count: u32,
- lines: Vec<HunkLine>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-enum HunkLine {
- Context(String),
- Addition(String),
- Deletion(String),
-}
-
-fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
- let line = line.strip_prefix("@@ -")?;
- let (old_part, rest) = line.split_once(' ')?;
- let rest = rest.strip_prefix('+')?;
- let (new_part, _) = rest.split_once(" @@")?;
-
- let (old_start, old_count) = if let Some((start, count)) = old_part.split_once(',') {
- (start.parse().ok()?, count.parse().ok()?)
- } else {
- (old_part.parse().ok()?, 1)
- };
-
- let (new_start, new_count) = if let Some((start, count)) = new_part.split_once(',') {
- (start.parse().ok()?, count.parse().ok()?)
- } else {
- (new_part.parse().ok()?, 1)
- };
-
- Some((old_start, old_count, new_start, new_count))
-}
-
-fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
- let mut hunks = Vec::new();
- let mut current_hunk: Option<ParsedHunk> = None;
-
- for line in diff.lines() {
- if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
- if let Some(hunk) = current_hunk.take() {
- hunks.push(hunk);
- }
- current_hunk = Some(ParsedHunk {
- old_start,
- old_count,
- new_start,
- new_count,
- lines: Vec::new(),
- });
- } else if let Some(ref mut hunk) = current_hunk {
- if let Some(stripped) = line.strip_prefix('+') {
- hunk.lines.push(HunkLine::Addition(stripped.to_string()));
- } else if let Some(stripped) = line.strip_prefix('-') {
- hunk.lines.push(HunkLine::Deletion(stripped.to_string()));
- } else if let Some(stripped) = line.strip_prefix(' ') {
- hunk.lines.push(HunkLine::Context(stripped.to_string()));
- } else if line.is_empty() {
- hunk.lines.push(HunkLine::Context(String::new()));
- }
- }
- }
-
- if let Some(hunk) = current_hunk {
- hunks.push(hunk);
- }
-
- hunks
-}
-
-fn format_hunk(hunk: &ParsedHunk) -> String {
- let mut result = format!(
- "@@ -{},{} +{},{} @@\n",
- hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
- );
- for line in &hunk.lines {
- match line {
- HunkLine::Context(text) => {
- result.push(' ');
- result.push_str(text);
- result.push('\n');
- }
- HunkLine::Addition(text) => {
- result.push('+');
- result.push_str(text);
- result.push('\n');
- }
- HunkLine::Deletion(text) => {
- result.push('-');
- result.push_str(text);
- result.push('\n');
- }
- }
- }
- result
-}
-
-fn filter_diff_hunks_by_excerpt(
- diff: &str,
- excerpt_start_row: u32,
- excerpt_row_count: u32,
-) -> (String, i32) {
- let hunks = parse_diff_hunks(diff);
- let excerpt_start_0based = excerpt_start_row;
- let excerpt_end_0based = excerpt_start_row + excerpt_row_count;
-
- let mut filtered_hunks = Vec::new();
- let mut cumulative_line_offset: i32 = 0;
-
- for hunk in hunks {
- let hunk_start_0based = hunk.new_start.saturating_sub(1);
- let hunk_end_0based = hunk_start_0based + hunk.new_count;
-
- let additions: i32 = hunk
- .lines
- .iter()
- .filter(|l| matches!(l, HunkLine::Addition(_)))
- .count() as i32;
- let deletions: i32 = hunk
- .lines
- .iter()
- .filter(|l| matches!(l, HunkLine::Deletion(_)))
- .count() as i32;
- let hunk_line_delta = additions - deletions;
-
- if hunk_end_0based <= excerpt_start_0based {
- cumulative_line_offset += hunk_line_delta;
- continue;
- }
-
- if hunk_start_0based >= excerpt_end_0based {
- continue;
- }
-
- let mut filtered_lines = Vec::new();
- let mut current_row_0based = hunk_start_0based;
- let mut filtered_old_count = 0u32;
- let mut filtered_new_count = 0u32;
- let mut first_included_row: Option<u32> = None;
-
- for line in &hunk.lines {
- match line {
- HunkLine::Context(text) => {
- if current_row_0based >= excerpt_start_0based
- && current_row_0based < excerpt_end_0based
- {
- if first_included_row.is_none() {
- first_included_row = Some(current_row_0based);
- }
- filtered_lines.push(HunkLine::Context(text.clone()));
- filtered_old_count += 1;
- filtered_new_count += 1;
- }
- current_row_0based += 1;
- }
- HunkLine::Addition(text) => {
- if current_row_0based >= excerpt_start_0based
- && current_row_0based < excerpt_end_0based
- {
- if first_included_row.is_none() {
- first_included_row = Some(current_row_0based);
- }
- filtered_lines.push(HunkLine::Addition(text.clone()));
- filtered_new_count += 1;
- }
- current_row_0based += 1;
- }
- HunkLine::Deletion(text) => {
- if current_row_0based >= excerpt_start_0based
- && current_row_0based < excerpt_end_0based
- {
- if first_included_row.is_none() {
- first_included_row = Some(current_row_0based);
- }
- filtered_lines.push(HunkLine::Deletion(text.clone()));
- filtered_old_count += 1;
- }
- }
- }
- }
-
- if !filtered_lines.is_empty() {
- let first_row = first_included_row.unwrap_or(excerpt_start_0based);
- let new_start_1based = (first_row - excerpt_start_0based) + 1;
-
- filtered_hunks.push(ParsedHunk {
- old_start: new_start_1based,
- old_count: filtered_old_count,
- new_start: new_start_1based,
- new_count: filtered_new_count,
- lines: filtered_lines,
- });
- }
-
- cumulative_line_offset += hunk_line_delta;
- }
-
- let mut result = String::new();
- for hunk in &filtered_hunks {
- result.push_str(&format_hunk(hunk));
- }
-
- (result, cumulative_line_offset)
-}
-
-fn compute_excerpt_aware_reversal_overlap(
- edit_history_diffs: &[&str],
- excerpt_content: &str,
- excerpt_start_row: u32,
- predicted_content: &str,
-) -> ReversalOverlap {
- let mut current_content = excerpt_content.to_string();
- let mut current_excerpt_start_row = excerpt_start_row;
-
- for diff in edit_history_diffs.iter().rev() {
- if diff.is_empty() {
- continue;
- }
-
- let current_row_count = current_content.lines().count() as u32;
- let (filtered_diff, _line_offset) =
- filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1));
-
- if filtered_diff.is_empty() {
- let hunks = parse_diff_hunks(diff);
- for hunk in hunks {
- let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
- if hunk_end <= current_excerpt_start_row {
- let additions: u32 = hunk
- .lines
- .iter()
- .filter(|l| matches!(l, HunkLine::Addition(_)))
- .count() as u32;
- let deletions: u32 = hunk
- .lines
- .iter()
- .filter(|l| matches!(l, HunkLine::Deletion(_)))
- .count() as u32;
- if additions >= deletions {
- current_excerpt_start_row =
- current_excerpt_start_row.saturating_sub(additions - deletions);
- } else {
- current_excerpt_start_row += deletions - additions;
- }
- }
- }
- continue;
- }
-
- let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff));
- match apply_diff_to_string(&reversed, ¤t_content) {
- Ok(updated) => {
- current_content = updated;
- }
- Err(_) => {
- continue;
- }
- }
-
- let hunks = parse_diff_hunks(diff);
- for hunk in hunks {
- let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
- if hunk_end <= current_excerpt_start_row {
- let additions: u32 = hunk
- .lines
- .iter()
- .filter(|l| matches!(l, HunkLine::Addition(_)))
- .count() as u32;
- let deletions: u32 = hunk
- .lines
- .iter()
- .filter(|l| matches!(l, HunkLine::Deletion(_)))
- .count() as u32;
- if additions >= deletions {
- current_excerpt_start_row =
- current_excerpt_start_row.saturating_sub(additions - deletions);
- } else {
- current_excerpt_start_row += deletions - additions;
- }
- }
- }
- }
-
- compute_reversal_overlap(¤t_content, excerpt_content, predicted_content)
-}
-
-fn reverse_diff(diff: &str) -> String {
- let mut result: String = diff
- .lines()
- .map(|line| {
- if line.starts_with("--- ") {
- line.replacen("--- ", "+++ ", 1)
- } else if line.starts_with("+++ ") {
- line.replacen("+++ ", "--- ", 1)
- } else if line.starts_with('+') && !line.starts_with("+++") {
- format!("-{}", &line[1..])
- } else if line.starts_with('-') && !line.starts_with("---") {
- format!("+{}", &line[1..])
- } else {
- line.to_string()
- }
- })
- .collect::<Vec<_>>()
- .join("\n");
- if diff.ends_with('\n') {
- result.push('\n');
- }
- result
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-struct GranularEdit {
- range: Range<usize>,
- old_text: String,
- new_text: String,
-}
-
-fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
- text_diff(old_text, new_text)
- .into_iter()
- .map(|(range, new_text)| GranularEdit {
- old_text: old_text[range.clone()].to_string(),
- range,
- new_text: new_text.to_string(),
- })
- .collect()
-}
-
-#[derive(Debug, Clone)]
-struct HistoryAdditionRange {
- range_in_current: Range<usize>,
-}
-
-#[derive(Debug, Clone)]
-struct HistoryDeletionRange {
- deleted_text: String,
- position_in_current: usize,
-}
-
-fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
- let mut result = Vec::new();
- let mut offset_delta: isize = 0;
-
- for edit in history_edits {
- if !edit.new_text.is_empty() {
- let new_start = (edit.range.start as isize + offset_delta) as usize;
- let new_end = new_start + edit.new_text.len();
- result.push(HistoryAdditionRange {
- range_in_current: new_start..new_end,
- });
- }
-
- offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
- }
-
- result
-}
-
-fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
- let mut result = Vec::new();
- let mut offset_delta: isize = 0;
-
- for edit in history_edits {
- if !edit.old_text.is_empty() {
- let position_in_current = (edit.range.start as isize + offset_delta) as usize;
- result.push(HistoryDeletionRange {
- deleted_text: edit.old_text.clone(),
- position_in_current,
- });
- }
-
- offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
- }
-
- result
-}
-
-#[derive(Debug, Clone, Default, PartialEq, Eq)]
-struct ReversalOverlap {
- chars_reversing_user_edits: usize,
- total_chars_in_prediction: usize,
-}
-
-impl ReversalOverlap {
- fn ratio(&self) -> f32 {
- if self.total_chars_in_prediction == 0 {
- 0.0
- } else {
- self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32
- }
- }
-}
-
-/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
-/// or where `new_text` appears as a subsequence within `old_text` (reduction).
-///
-/// For extensions: when the user's text is preserved (in order) within the prediction,
-/// we only count the newly inserted characters, not the preserved ones.
-/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
-/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
-///
-/// For reductions: when the prediction's text is preserved (in order) within the original,
-/// we only count the deleted characters, not the preserved ones.
-/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
-fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
- edits
- .into_iter()
- .flat_map(|edit| {
- if edit.old_text.is_empty() || edit.new_text.is_empty() {
- return vec![edit];
- }
-
- // Use character-wise diff to find exact byte ranges of changes
- let char_edits = char_diff(&edit.old_text, &edit.new_text);
-
- let all_deletions = !char_edits.is_empty()
- && char_edits
- .iter()
- .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
- let all_insertions = !char_edits.is_empty()
- && char_edits
- .iter()
- .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
- if all_deletions || all_insertions {
- return char_edits
- .into_iter()
- .map(|(range, replacement)| GranularEdit {
- range: edit.range.start + range.start..edit.range.start + range.end,
- old_text: edit.old_text[range].to_string(),
- new_text: replacement.to_string(),
- })
- .collect();
- }
-
- // Otherwise, keep the original edit (mixed changes)
- vec![edit]
- })
- .collect()
-}
-
-fn compute_reversal_overlap(
- original_content: &str,
- current_content: &str,
- predicted_content: &str,
-) -> ReversalOverlap {
- let history_edits =
- normalize_extension_edits(compute_granular_edits(original_content, current_content));
- let prediction_edits =
- normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
-
- let history_addition_ranges = compute_history_addition_ranges(&history_edits);
- let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
-
- let reversed_additions =
- compute_reversed_additions(&history_addition_ranges, &prediction_edits);
- let restored_deletions =
- compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
-
- let total_chars_in_prediction: usize = prediction_edits
- .iter()
- .map(|e| e.new_text.chars().count() + e.old_text.chars().count())
- .sum();
-
- ReversalOverlap {
- chars_reversing_user_edits: reversed_additions + restored_deletions,
- total_chars_in_prediction,
- }
-}
-
-fn compute_reversed_additions(
- history_addition_ranges: &[HistoryAdditionRange],
- prediction_edits: &[GranularEdit],
-) -> usize {
- let mut reversed_chars = 0;
-
- for pred_edit in prediction_edits {
- for history_addition in history_addition_ranges {
- let overlap_start = pred_edit
- .range
- .start
- .max(history_addition.range_in_current.start);
- let overlap_end = pred_edit
- .range
- .end
- .min(history_addition.range_in_current.end);
-
- if overlap_start < overlap_end {
- let relative_start = overlap_start - pred_edit.range.start;
- let relative_end = overlap_end - pred_edit.range.start;
- let overlap_text = &pred_edit.old_text[relative_start..relative_end];
- reversed_chars += overlap_text.chars().count();
- }
- }
- }
-
- reversed_chars
-}
-
-fn compute_restored_deletions(
- history_deletion_ranges: &[HistoryDeletionRange],
- prediction_edits: &[GranularEdit],
-) -> usize {
- let mut restored = 0;
-
- for pred_edit in prediction_edits {
- if pred_edit.new_text.is_empty() {
- continue;
- }
-
- for deletion in history_deletion_ranges {
- if pred_edit.range.contains(&deletion.position_in_current)
- || deletion.position_in_current == pred_edit.range.start
- {
- restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text);
- }
- }
- }
-
- restored
-}
-
-fn compute_lcs_length(a: &str, b: &str) -> usize {
- let a_chars: Vec<char> = a.chars().collect();
- let b_chars: Vec<char> = b.chars().collect();
- let m = a_chars.len();
- let n = b_chars.len();
-
- if m == 0 || n == 0 {
- return 0;
- }
-
- let mut prev = vec![0; n + 1];
- let mut curr = vec![0; n + 1];
-
- for i in 1..=m {
- for j in 1..=n {
- if a_chars[i - 1] == b_chars[j - 1] {
- curr[j] = prev[j - 1] + 1;
- } else {
- curr[j] = prev[j].max(curr[j - 1]);
- }
- }
- std::mem::swap(&mut prev, &mut curr);
- curr.fill(0);
- }
-
- prev[n]
-}
-
-fn filter_edit_history_by_path<'a>(
- edit_history: &'a [Arc<zeta_prompt::Event>],
- cursor_path: &std::path::Path,
-) -> Vec<&'a zeta_prompt::Event> {
- edit_history
- .iter()
- .filter(|event| match event.as_ref() {
- zeta_prompt::Event::BufferChange { path, .. } => {
- let event_path = path.as_ref();
- if event_path == cursor_path {
- return true;
- }
- let stripped = event_path
- .components()
- .skip(1)
- .collect::<std::path::PathBuf>();
- stripped == cursor_path
- }
- })
- .map(|arc| arc.as_ref())
- .collect()
-}
-
-fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
- match event {
- zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
- }
-}
-
-fn is_predicted_event(event: &zeta_prompt::Event) -> bool {
- match event {
- zeta_prompt::Event::BufferChange { predicted, .. } => *predicted,
- }
-}
-
pub fn compute_prediction_reversal_ratio(
prompt_inputs: &ZetaPromptInput,
predicted_content: &str,
cursor_path: &Path,
) -> f32 {
- let current_content: &str = prompt_inputs.cursor_excerpt.as_ref();
-
- let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.events;
- let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
-
- let most_recent = match relevant_events.last() {
- Some(event) if !is_predicted_event(event) => *event,
- _ => return 0.0,
- };
-
- let diff = extract_diff_from_event(most_recent);
- if diff.is_empty() {
- return 0.0;
- }
-
- if let Some(excerpt_start_row) = prompt_inputs.excerpt_start_row {
- let diffs = vec![diff];
- let overlap = compute_excerpt_aware_reversal_overlap(
- &diffs,
- current_content,
- excerpt_start_row,
- predicted_content,
- );
- return overlap.ratio();
- }
-
- let reversed = reverse_diff(diff);
- let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
- let original_content = match apply_diff_to_string(&with_headers, current_content) {
- Ok(updated_content) => updated_content,
- Err(_) => apply_diff_to_string_lenient(&reversed, current_content),
- };
-
- let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
- overlap.ratio()
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use indoc::indoc;
- use zeta_prompt::ExcerptRanges;
- use zeta_prompt::udiff::apply_diff_to_string;
-
- fn make_test_prompt_inputs(
- content: &str,
- events: Vec<Arc<zeta_prompt::Event>>,
- excerpt_start_row: Option<u32>,
- ) -> ZetaPromptInput {
- ZetaPromptInput {
- cursor_path: Arc::from(Path::new("src/test.rs")),
- cursor_excerpt: content.into(),
- cursor_offset_in_excerpt: 0,
- excerpt_start_row,
- events,
- related_files: Some(Vec::new()),
- active_buffer_diagnostics: Vec::new(),
- excerpt_ranges: ExcerptRanges {
- editable_150: 0..content.len(),
- editable_180: 0..content.len(),
- editable_350: 0..content.len(),
- editable_150_context_350: 0..content.len(),
- editable_180_context_350: 0..content.len(),
- editable_350_context_150: 0..content.len(),
- ..Default::default()
- },
- syntax_ranges: None,
- experiment: None,
- in_open_source_repo: false,
- can_collect_data: false,
- repo_url: None,
- }
- }
-
- #[test]
- fn test_reversal_overlap() {
- struct Case {
- name: &'static str,
- original: &'static str,
- current: &'static str,
- predicted: &'static str,
- expected_reversal_chars: usize,
- expected_total_chars: usize,
- }
-
- let cases = [
- Case {
- name: "user_adds_line_prediction_removes_it",
- original: indoc! {"
- a
- b
- c"},
- current: indoc! {"
- a
- new line
- b
- c"},
- predicted: indoc! {"
- a
- b
- c"},
- expected_reversal_chars: 9,
- expected_total_chars: 9,
- },
- Case {
- name: "user_deletes_line_prediction_restores_it",
- original: indoc! {"
- a
- deleted
- b"},
- current: indoc! {"
- a
- b"},
- predicted: indoc! {"
- a
- deleted
- b"},
- expected_reversal_chars: 8,
- expected_total_chars: 8,
- },
- Case {
- name: "user_deletes_text_prediction_restores_partial",
- original: "hello beautiful world",
- current: "hello world",
- predicted: "hello beautiful world",
- expected_reversal_chars: 10,
- expected_total_chars: 10,
- },
- Case {
- name: "user_deletes_foo_prediction_adds_bar",
- original: "foo",
- current: "",
- predicted: "bar",
- expected_reversal_chars: 0,
- expected_total_chars: 3,
- },
- Case {
- name: "independent_edits_different_locations",
- original: indoc! {"
- line1
- line2
- line3"},
- current: indoc! {"
- LINE1
- line2
- line3"},
- predicted: indoc! {"
- LINE1
- line2
- LINE3"},
- expected_reversal_chars: 0,
- expected_total_chars: 10,
- },
- Case {
- name: "no_history_edits",
- original: "same",
- current: "same",
- predicted: "different",
- expected_reversal_chars: 0,
- expected_total_chars: 13,
- },
- Case {
- name: "user_replaces_text_prediction_reverses",
- original: indoc! {"
- keep
- delete_me
- keep2"},
- current: indoc! {"
- keep
- added
- keep2"},
- predicted: indoc! {"
- keep
- delete_me
- keep2"},
- expected_reversal_chars: 14,
- expected_total_chars: 14,
- },
- Case {
- name: "user_modifies_word_prediction_modifies_differently",
- original: "the quick brown fox",
- current: "the slow brown fox",
- predicted: "the fast brown fox",
- expected_reversal_chars: 4,
- expected_total_chars: 8,
- },
- Case {
- name: "user finishes function name (suffix)",
- original: "",
- current: "epr",
- predicted: "eprintln!()",
- expected_reversal_chars: 0,
- expected_total_chars: 8,
- },
- Case {
- name: "user starts function name (prefix)",
- original: "",
- current: "my_function()",
- predicted: "test_my_function()",
- expected_reversal_chars: 0,
- expected_total_chars: 5,
- },
- Case {
- name: "user types partial, prediction extends in multiple places",
- original: "",
- current: "test_my_function",
- predicted: "a_test_for_my_special_function_plz",
- expected_reversal_chars: 0,
- expected_total_chars: 18,
- },
- // Edge cases for subsequence matching
- Case {
- name: "subsequence with interleaved underscores",
- original: "",
- current: "a_b_c",
- predicted: "_a__b__c__",
- expected_reversal_chars: 0,
- expected_total_chars: 5,
- },
- Case {
- name: "not a subsequence - different characters",
- original: "",
- current: "abc",
- predicted: "xyz",
- expected_reversal_chars: 3,
- expected_total_chars: 6,
- },
- Case {
- name: "not a subsequence - wrong order",
- original: "",
- current: "abc",
- predicted: "cba",
- expected_reversal_chars: 3,
- expected_total_chars: 6,
- },
- Case {
- name: "partial subsequence - only some chars match",
- original: "",
- current: "abcd",
- predicted: "axbx",
- expected_reversal_chars: 4,
- expected_total_chars: 8,
- },
- // Common completion patterns
- Case {
- name: "completing a method call",
- original: "",
- current: "vec.pu",
- predicted: "vec.push(item)",
- expected_reversal_chars: 0,
- expected_total_chars: 8,
- },
- Case {
- name: "completing an import statement",
- original: "",
- current: "use std::col",
- predicted: "use std::collections::HashMap",
- expected_reversal_chars: 0,
- expected_total_chars: 17,
- },
- Case {
- name: "completing a struct field",
- original: "",
- current: "name: St",
- predicted: "name: String",
- expected_reversal_chars: 0,
- expected_total_chars: 4,
- },
- Case {
- name: "prediction replaces with completely different text",
- original: "",
- current: "hello",
- predicted: "world",
- expected_reversal_chars: 5,
- expected_total_chars: 10,
- },
- Case {
- name: "empty prediction removes user text",
- original: "",
- current: "mistake",
- predicted: "",
- expected_reversal_chars: 7,
- expected_total_chars: 7,
- },
- Case {
- name: "fixing typo is not reversal",
- original: "",
- current: "<dv",
- predicted: "<div>",
- expected_reversal_chars: 0,
- expected_total_chars: 2,
- },
- Case {
- name: "infix insertion not reversal",
- original: indoc! {"
- from my_project import Foo
- "},
- current: indoc! {"
- ifrom my_project import Foo
- "},
- predicted: indoc! {"
- import
- from my_project import Foo
- "},
- expected_reversal_chars: 0,
- expected_total_chars: 6,
- },
- Case {
- name: "non-word based reversal",
- original: "from",
- current: "ifrom",
- predicted: "from",
- expected_reversal_chars: 1,
- expected_total_chars: 1,
- },
- Case {
- name: "multiple insertions no reversal",
- original: "print(\"Hello, World!\")",
- current: "sys.(\"Hello, World!\")",
- predicted: "sys.stdout.write(\"Hello, World!\\n\")",
- expected_reversal_chars: 0,
- expected_total_chars: 14,
- },
- ];
-
- for case in &cases {
- let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
- assert_eq!(
- overlap.chars_reversing_user_edits, case.expected_reversal_chars,
- "Test '{}': expected {} reversal chars, got {}",
- case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
- );
- assert_eq!(
- overlap.total_chars_in_prediction, case.expected_total_chars,
- "Test '{}': expected {} total chars, got {}",
- case.name, case.expected_total_chars, overlap.total_chars_in_prediction
- );
- }
- }
-
- #[test]
- fn test_reverse_diff() {
- let forward_diff = indoc! {"
- --- a/file.rs
- +++ b/file.rs
- @@ -1,3 +1,4 @@
- fn main() {
- + let x = 42;
- println!(\"hello\");
- }"};
-
- let reversed = reverse_diff(forward_diff);
-
- assert!(
- reversed.contains("+++ a/file.rs"),
- "Should have +++ for old path"
- );
- assert!(
- reversed.contains("--- b/file.rs"),
- "Should have --- for new path"
- );
- assert!(
- reversed.contains("- let x = 42;"),
- "Added line should become deletion"
- );
- assert!(
- reversed.contains(" fn main()"),
- "Context lines should be unchanged"
- );
- }
-
- #[test]
- fn test_reverse_diff_roundtrip() {
- // Applying a diff and then its reverse should get back to original
- let original = indoc! {"
- first line
- hello world
- last line
- "};
- let modified = indoc! {"
- first line
- hello beautiful world
- last line
- "};
-
- // unified_diff doesn't include file headers, but apply_diff_to_string needs them
- let diff_body = language::unified_diff(original, modified);
- let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
- let reversed_diff = reverse_diff(&forward_diff);
-
- // Apply forward diff to original
- let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
- assert_eq!(after_forward, modified);
-
- // Apply reversed diff to modified
- let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
- assert_eq!(after_reverse, original);
- }
-
- #[test]
- fn test_filter_edit_history_by_path() {
- // Test that filter_edit_history_by_path correctly matches paths when
- // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
- // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
- let events = vec![
- Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("myrepo/src/file.rs")),
- old_path: Arc::from(Path::new("myrepo/src/file.rs")),
- diff: indoc! {"
- @@ -1 +1 @@
- -old
- +new"}
- .into(),
- predicted: false,
- in_open_source_repo: true,
- }),
- Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("myrepo/other.rs")),
- old_path: Arc::from(Path::new("myrepo/other.rs")),
- diff: indoc! {"
- @@ -1 +1 @@
- -a
- +b"}
- .into(),
- predicted: false,
- in_open_source_repo: true,
- }),
- Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("src/file.rs")),
- old_path: Arc::from(Path::new("src/file.rs")),
- diff: indoc! {"
- @@ -1 +1 @@
- -x
- +y"}
- .into(),
- predicted: false,
- in_open_source_repo: true,
- }),
- ];
-
- // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
- // "src/file.rs" exact match
- let cursor_path = Path::new("src/file.rs");
- let filtered = filter_edit_history_by_path(&events, cursor_path);
- assert_eq!(
- filtered.len(),
- 2,
- "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
- );
-
- // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
- // "src/file.rs" stripped -> "file.rs" == "file.rs"
- let cursor_path = Path::new("file.rs");
- let filtered = filter_edit_history_by_path(&events, cursor_path);
- assert_eq!(
- filtered.len(),
- 1,
- "Should only match src/file.rs (stripped to file.rs)"
- );
-
- // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
- let cursor_path = Path::new("other.rs");
- let filtered = filter_edit_history_by_path(&events, cursor_path);
- assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
- }
-
- #[test]
- fn test_reverse_diff_preserves_trailing_newline() {
- let diff_with_trailing_newline = indoc! {"
- --- a/file
- +++ b/file
- @@ -1 +1 @@
- -old
- +new
- "};
- let reversed = reverse_diff(diff_with_trailing_newline);
- assert!(
- reversed.ends_with('\n'),
- "Reversed diff should preserve trailing newline"
- );
-
- let diff_without_trailing_newline = indoc! {"
- --- a/file
- +++ b/file
- @@ -1 +1 @@
- -old
- +new"};
- let reversed = reverse_diff(diff_without_trailing_newline);
- assert!(
- !reversed.ends_with('\n'),
- "Reversed diff should not add trailing newline if original didn't have one"
- );
- }
-
- #[test]
- fn test_filter_hunks_by_excerpt_region() {
- struct Case {
- name: &'static str,
- diff: &'static str,
- excerpt_start_row: u32,
- excerpt_row_count: u32,
- expected_filtered_diff: &'static str,
- expected_line_offset: i32,
- }
-
- let cases = [
- Case {
- name: "hunk_entirely_before_excerpt",
- diff: indoc! {"
- @@ -1,3 +1,4 @@
- line1
- +inserted
- line2
- line3
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 5,
- expected_filtered_diff: "",
- expected_line_offset: 1,
- },
- Case {
- name: "hunk_entirely_inside_excerpt",
- diff: indoc! {"
- @@ -12,3 +12,4 @@
- line12
- +inserted
- line13
- line14
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 10,
- expected_filtered_diff: indoc! {"
- @@ -2,3 +2,4 @@
- line12
- +inserted
- line13
- line14
- "},
- expected_line_offset: 1,
- },
- Case {
- name: "hunk_entirely_after_excerpt",
- diff: indoc! {"
- @@ -50,3 +50,4 @@
- line50
- +inserted
- line51
- line52
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 5,
- expected_filtered_diff: "",
- expected_line_offset: 0,
- },
- Case {
- name: "hunk_straddles_excerpt_start",
- diff: indoc! {"
- @@ -8,5 +8,6 @@
- line8
- line9
- +inserted
- line10
- line11
- line12
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 10,
- expected_filtered_diff: indoc! {"
- @@ -1,3 +1,3 @@
- line10
- line11
- line12
- "},
- expected_line_offset: 1,
- },
- Case {
- name: "hunk_straddles_excerpt_end",
- diff: indoc! {"
- @@ -18,5 +18,6 @@
- line18
- line19
- +inserted
- line20
- line21
- line22
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 10,
- expected_filtered_diff: indoc! {"
- @@ -8,2 +8,3 @@
- line18
- line19
- +inserted
- "},
- expected_line_offset: 1,
- },
- Case {
- name: "multiple_hunks_mixed",
- diff: indoc! {"
- @@ -1,2 +1,3 @@
- line1
- +before_excerpt
- line2
- @@ -12,2 +13,3 @@
- line12
- +inside_excerpt
- line13
- @@ -50,2 +52,3 @@
- line50
- +after_excerpt
- line51
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 10,
- expected_filtered_diff: indoc! {"
- @@ -3,2 +3,3 @@
- line12
- +inside_excerpt
- line13
- "},
- expected_line_offset: 2,
- },
- Case {
- name: "deletion_before_excerpt",
- diff: indoc! {"
- @@ -1,4 +1,3 @@
- line1
- -deleted
- line2
- line3
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 5,
- expected_filtered_diff: "",
- expected_line_offset: -1,
- },
- Case {
- name: "deletion_inside_excerpt",
- diff: indoc! {"
- @@ -12,4 +12,3 @@
- line12
- -deleted
- line13
- line14
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 10,
- expected_filtered_diff: indoc! {"
- @@ -2,4 +2,3 @@
- line12
- -deleted
- line13
- line14
- "},
- expected_line_offset: -1,
- },
- Case {
- name: "empty_diff",
- diff: "",
- excerpt_start_row: 10,
- excerpt_row_count: 5,
- expected_filtered_diff: "",
- expected_line_offset: 0,
- },
- Case {
- name: "hunk_spans_entire_excerpt",
- diff: indoc! {"
- @@ -8,10 +8,12 @@
- line8
- line9
- line10
- line11
- +inserted1
- line12
- line13
- +inserted2
- line14
- line15
- line16
- line17
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 5,
- expected_filtered_diff: indoc! {"
- @@ -1,3 +1,5 @@
- line11
- +inserted1
- line12
- line13
- +inserted2
- "},
- expected_line_offset: 2,
- },
- Case {
- name: "replacement_inside_excerpt",
- diff: indoc! {"
- @@ -12,3 +12,3 @@
- line12
- -old_text
- +new_text
- line14
- "},
- excerpt_start_row: 10,
- excerpt_row_count: 10,
- expected_filtered_diff: indoc! {"
- @@ -2,3 +2,3 @@
- line12
- -old_text
- +new_text
- line14
- "},
- expected_line_offset: 0,
- },
- ];
-
- for case in &cases {
- let (filtered, line_offset) = filter_diff_hunks_by_excerpt(
- case.diff,
- case.excerpt_start_row,
- case.excerpt_row_count,
- );
- assert_eq!(
- filtered, case.expected_filtered_diff,
- "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}",
- case.name, case.expected_filtered_diff, filtered
- );
- assert_eq!(
- line_offset, case.expected_line_offset,
- "Test '{}': line offset mismatch. Expected {}, got {}",
- case.name, case.expected_line_offset, line_offset
- );
- }
- }
-
- #[test]
- fn test_excerpt_aware_reversal_tracking() {
- struct Case {
- name: &'static str,
- edit_history_diffs: Vec<&'static str>,
- excerpt_content: &'static str,
- excerpt_start_row: u32,
- predicted_content: &'static str,
- expected_reversal_chars: usize,
- expected_total_chars: usize,
- }
-
- let cases = [
- Case {
- name: "edit_outside_excerpt_no_reversal",
- edit_history_diffs: vec![indoc! {"
- @@ -1,2 +1,3 @@
- line1
- +added_outside
- line2
- "}],
- excerpt_content: indoc! {"
- line10
- line11
- line12
- "},
- excerpt_start_row: 10,
- predicted_content: indoc! {"
- line10
- modified
- line12
- "},
- expected_reversal_chars: 0,
- expected_total_chars: 14,
- },
- Case {
- name: "edit_inside_excerpt_with_reversal",
- edit_history_diffs: vec![indoc! {"
- @@ -10,3 +10,4 @@
- line10
- +user_added
- line11
- line12
- "}],
- excerpt_content: indoc! {"
- line10
- user_added
- line11
- line12
- "},
- excerpt_start_row: 10,
- predicted_content: indoc! {"
- line10
- line11
- line12
- "},
- expected_reversal_chars: 11,
- expected_total_chars: 11,
- },
- Case {
- name: "straddling_edit_partial_reversal",
- edit_history_diffs: vec![indoc! {"
- @@ -8,6 +8,8 @@
- line8
- line9
- +before_excerpt
- line10
- +inside_excerpt
- line11
- line12
- line13
- "}],
- excerpt_content: indoc! {"
- line10
- inside_excerpt
- line11
- line12
- line13
- "},
- excerpt_start_row: 10,
- predicted_content: indoc! {"
- line10
- line11
- line12
- line13
- "},
- expected_reversal_chars: 15,
- expected_total_chars: 15,
- },
- Case {
- name: "multiple_edits_mixed_locations",
- edit_history_diffs: vec![
- indoc! {"
- @@ -1,2 +1,3 @@
- line1
- +outside1
- line2
- "},
- indoc! {"
- @@ -11,2 +12,3 @@
- line11
- +inside1
- line12
- "},
- ],
- excerpt_content: indoc! {"
- line10
- line11
- inside1
- line12
- line13
- "},
- excerpt_start_row: 10,
- predicted_content: indoc! {"
- line10
- line11
- line12
- line13
- "},
- expected_reversal_chars: 8,
- expected_total_chars: 8,
- },
- Case {
- name: "no_edit_history",
- edit_history_diffs: vec![],
- excerpt_content: indoc! {"
- line10
- line11
- line12
- "},
- excerpt_start_row: 10,
- predicted_content: indoc! {"
- line10
- modified
- line12
- "},
- expected_reversal_chars: 0,
- expected_total_chars: 14,
- },
- Case {
- name: "edit_after_excerpt_no_effect",
- edit_history_diffs: vec![indoc! {"
- @@ -50,2 +50,3 @@
- line50
- +added_after
- line51
- "}],
- excerpt_content: indoc! {"
- line10
- line11
- line12
- "},
- excerpt_start_row: 10,
- predicted_content: indoc! {"
- line10
- changed
- line12
- "},
- expected_reversal_chars: 0,
- expected_total_chars: 13,
- },
- Case {
- name: "line_offset_tracking_across_hunks",
- edit_history_diffs: vec![
- indoc! {"
- @@ -1,2 +1,4 @@
- line1
- +added1
- +added2
- line2
- "},
- indoc! {"
- @@ -12,2 +14,3 @@
- line12
- +inside_after_offset
- line13
- "},
- ],
- excerpt_content: indoc! {"
- line10
- line11
- line12
- inside_after_offset
- line13
- "},
- excerpt_start_row: 10,
- predicted_content: indoc! {"
- line10
- line11
- line12
- line13
- "},
- expected_reversal_chars: 20,
- expected_total_chars: 20,
- },
- ];
-
- for case in &cases {
- let overlap = compute_excerpt_aware_reversal_overlap(
- &case.edit_history_diffs,
- case.excerpt_content,
- case.excerpt_start_row,
- case.predicted_content,
- );
- assert_eq!(
- overlap.chars_reversing_user_edits, case.expected_reversal_chars,
- "Test '{}': expected {} reversal chars, got {}",
- case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
- );
- assert_eq!(
- overlap.total_chars_in_prediction, case.expected_total_chars,
- "Test '{}': expected {} total chars, got {}",
- case.name, case.expected_total_chars, overlap.total_chars_in_prediction
- );
- }
- }
-
- #[test]
- fn test_lenient_diff_application() {
- struct Case {
- name: &'static str,
- diff: &'static str,
- content: &'static str,
- expected_result: &'static str,
- }
-
- let cases = [
- Case {
- name: "hunk_context_not_found_skipped",
- diff: indoc! {"
- @@ -1,3 +1,4 @@
- context_not_in_content
- +added_line
- more_context
- final_context
- "},
- content: indoc! {"
- completely
- different
- content
- "},
- expected_result: indoc! {"
- completely
- different
- content
- "},
- },
- Case {
- name: "hunk_context_found_applied",
- diff: indoc! {"
- @@ -1,3 +1,4 @@
- line1
- +inserted
- line2
- line3
- "},
- content: indoc! {"
- line1
- line2
- line3
- "},
- expected_result: indoc! {"
- line1
- inserted
- line2
- line3
- "},
- },
- Case {
- name: "multiple_hunks_partial_match",
- diff: indoc! {"
- @@ -1,2 +1,3 @@
- not_found
- +skipped
- also_not_found
- @@ -5,2 +6,3 @@
- line5
- +applied
- line6
- "},
- content: indoc! {"
- line1
- line2
- line3
- line4
- line5
- line6
- "},
- expected_result: indoc! {"
- line1
- line2
- line3
- line4
- line5
- applied
- line6
- "},
- },
- Case {
- name: "empty_diff",
- diff: "",
- content: indoc! {"
- unchanged
- content
- "},
- expected_result: indoc! {"
- unchanged
- content
- "},
- },
- ];
-
- for case in &cases {
- let result = apply_diff_to_string_lenient(case.diff, case.content);
- assert_eq!(
- result, case.expected_result,
- "Test '{}': expected:\n{}\ngot:\n{}",
- case.name, case.expected_result, result
- );
- }
- }
-
- #[test]
- fn test_unicode_reversal_overlap() {
- struct Case {
- name: &'static str,
- original: &'static str,
- current: &'static str,
- predicted: &'static str,
- expected_reversal_chars: usize,
- expected_total_chars: usize,
- }
-
- let cases = [
- Case {
- name: "unicode_extension_cjk",
- original: "",
- current: "日", // 1 char
- predicted: "日本語", // 3 chars, adds 2 chars
- expected_reversal_chars: 0,
- expected_total_chars: 2, // "本語" = 2 chars added
- },
- Case {
- name: "unicode_extension_emoji",
- original: "",
- current: "🎉", // 1 char
- predicted: "🎉🎊🎈", // 3 chars, adds 2 chars
- expected_reversal_chars: 0,
- expected_total_chars: 2, // "🎊🎈" = 2 chars added
- },
- Case {
- name: "unicode_deletion_restored",
- original: "héllo wörld", // 11 chars
- current: "héllo", // 5 chars
- predicted: "héllo wörld", // restores " wörld" = 6 chars
- expected_reversal_chars: 6, // LCS(" wörld", " wörld") = 6 chars
- expected_total_chars: 6,
- },
- Case {
- name: "unicode_addition_reversed",
- original: "café", // 4 chars
- current: "café latté", // 10 chars, added " latté" = 6 chars
- predicted: "café", // removes " latté"
- expected_reversal_chars: 6, // 6 chars removed
- expected_total_chars: 6,
- },
- Case {
- name: "mixed_ascii_unicode",
- original: "",
- current: "test日本", // 6 chars
- predicted: "test日本語です", // 9 chars
- expected_reversal_chars: 0,
- expected_total_chars: 3, // 3 new chars after subsequence normalization
- },
- Case {
- name: "unicode_replacement_not_subsequence",
- original: "",
- current: "日本", // 2 chars
- predicted: "中国", // 2 chars, different
- expected_reversal_chars: 2, // removes "日本" = 2 chars
- expected_total_chars: 4, // 2 removed + 2 added
- },
- ];
-
- for case in &cases {
- let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
- assert_eq!(
- overlap.chars_reversing_user_edits, case.expected_reversal_chars,
- "Test '{}': expected {} reversal chars, got {}",
- case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
- );
- assert_eq!(
- overlap.total_chars_in_prediction, case.expected_total_chars,
- "Test '{}': expected {} total chars, got {}",
- case.name, case.expected_total_chars, overlap.total_chars_in_prediction
- );
- }
- }
-
- #[test]
- fn test_compute_lcs_length() {
- assert_eq!(compute_lcs_length("", ""), 0);
- assert_eq!(compute_lcs_length("abc", ""), 0);
- assert_eq!(compute_lcs_length("", "abc"), 0);
- assert_eq!(compute_lcs_length("abc", "abc"), 3);
- assert_eq!(compute_lcs_length("abc", "def"), 0);
- assert_eq!(compute_lcs_length("abcdef", "ace"), 3);
- assert_eq!(compute_lcs_length("AGGTAB", "GXTXAYB"), 4);
- assert_eq!(compute_lcs_length("日本語", "日語"), 2);
- }
-
- #[test]
- fn test_compute_prediction_reversal_ratio_full_file() {
- let prompt_inputs = make_test_prompt_inputs(
- indoc! {"
- line1
- user_added
- line2
- "},
- vec![Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("src/test.rs")),
- old_path: Arc::from(Path::new("src/test.rs")),
- diff: indoc! {"
- @@ -1,2 +1,3 @@
- line1
- +user_added
- line2
- "}
- .into(),
- predicted: false,
- in_open_source_repo: false,
- })],
- None,
- );
-
- let predicted = indoc! {"
- line1
- line2
- "};
- let ratio =
- compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
-
- assert!(
- ratio > 0.9,
- "Expected high reversal ratio when prediction removes user addition, got {}",
- ratio
- );
- }
-
- #[test]
- fn test_compute_prediction_reversal_ratio_with_excerpt() {
- let prompt_inputs = make_test_prompt_inputs(
- indoc! {"
- line10
- user_added
- line11
- "},
- vec![Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("src/test.rs")),
- old_path: Arc::from(Path::new("src/test.rs")),
- diff: indoc! {"
- @@ -10,2 +10,3 @@
- line10
- +user_added
- line11
- "}
- .into(),
- predicted: false,
- in_open_source_repo: false,
- })],
- Some(10),
- );
-
- let predicted = indoc! {"
- line10
- line11
- "};
- let ratio =
- compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
-
- assert!(
- ratio > 0.9,
- "Expected high reversal ratio for excerpt-aware computation, got {}",
- ratio
- );
- }
-
- #[test]
- fn test_compute_prediction_reversal_ratio_no_history() {
- let prompt_inputs = make_test_prompt_inputs(
- indoc! {"
- original content
- "},
- vec![],
- None,
- );
-
- let predicted = indoc! {"
- completely different
- "};
- let ratio =
- compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
-
- assert_eq!(
- ratio, 0.0,
- "Expected zero reversal ratio with no edit history"
- );
- }
-
- #[test]
- fn test_compute_prediction_reversal_ratio_path_filtering() {
- let prompt_inputs = make_test_prompt_inputs(
- indoc! {"
- line1
- user_added
- line2
- "},
- vec![Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("src/other.rs")),
- old_path: Arc::from(Path::new("src/other.rs")),
- diff: indoc! {"
- @@ -1,2 +1,3 @@
- line1
- +user_added
- line2
- "}
- .into(),
- predicted: false,
- in_open_source_repo: false,
- })],
- None,
- );
-
- let predicted = indoc! {"
- line1
- line2
- "};
- let ratio =
- compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
-
- assert_eq!(
- ratio, 0.0,
- "Expected zero reversal when edit history is for different file"
- );
- }
-
- #[test]
- fn test_compute_prediction_reversal_ratio_lenient_fallback() {
- let prompt_inputs = make_test_prompt_inputs(
- indoc! {"
- actual_line1
- user_added
- actual_line2
- "},
- vec![Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("src/test.rs")),
- old_path: Arc::from(Path::new("src/test.rs")),
- diff: indoc! {"
- @@ -1,2 +1,3 @@
- wrong_context
- +user_added
- more_wrong
- "}
- .into(),
- predicted: false,
- in_open_source_repo: false,
- })],
- None,
- );
-
- let predicted = indoc! {"
- actual_line1
- actual_line2
- "};
- let ratio =
- compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
-
- assert!(
- ratio >= 0.0 && ratio <= 1.0,
- "Ratio should be valid even with lenient fallback, got {}",
- ratio
- );
- }
-
- #[test]
- fn test_excerpt_aware_reversal_error_recovery() {
- let diffs = vec![indoc! {"
- @@ -1,2 +1,3 @@
- nonexistent_context
- +added
- more_nonexistent
- "}];
- let excerpt_content = indoc! {"
- completely
- different
- content
- "};
- let predicted_content = indoc! {"
- completely
- modified
- content
- "};
-
- let overlap =
- compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content);
-
- assert!(
- overlap.ratio() >= 0.0 && overlap.ratio() <= 1.0,
- "Should handle failed diff application gracefully"
- );
- }
-
- #[test]
- fn test_only_most_recent_edit_tracked() {
- let prompt_inputs = make_test_prompt_inputs(
- indoc! {"
- line1
- first_add
- second_add
- line2
- "},
- vec![
- Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("src/test.rs")),
- old_path: Arc::from(Path::new("src/test.rs")),
- diff: indoc! {"
- @@ -1,2 +1,3 @@
- line1
- +first_add
- line2
- "}
- .into(),
- predicted: false,
- in_open_source_repo: false,
- }),
- Arc::new(zeta_prompt::Event::BufferChange {
- path: Arc::from(Path::new("src/test.rs")),
- old_path: Arc::from(Path::new("src/test.rs")),
- diff: indoc! {"
- @@ -2,2 +2,3 @@
- first_add
- +second_add
- line2
- "}
- .into(),
- predicted: false,
- in_open_source_repo: false,
- }),
- ],
- None,
- );
-
- let predicted = indoc! {"
- line1
- first_add
- line2
- "};
- let ratio =
- compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
-
- assert!(
- ratio > 0.9,
- "Expected high reversal ratio when prediction exactly reverses the most recent edit, got {}",
- ratio
- );
- }
+ edit_prediction_metrics::compute_prediction_reversal_ratio_from_history(
+ prompt_inputs.cursor_excerpt.as_ref(),
+ &prompt_inputs.events,
+ prompt_inputs.excerpt_start_row,
+ predicted_content,
+ cursor_path,
+ )
}
@@ -0,0 +1,23 @@
+[package]
+name = "edit_prediction_metrics"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/edit_prediction_metrics.rs"
+
+[dependencies]
+language.workspace = true
+serde.workspace = true
+similar.workspace = true
+tree-sitter.workspace = true
+zeta_prompt.workspace = true
+
+[dev-dependencies]
+indoc.workspace = true
+pretty_assertions.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,24 @@
+mod kept_rate;
+mod patch_metrics;
+mod reversal;
+mod tokenize;
+mod tree_sitter;
+
+pub use kept_rate::KeptRateResult;
+#[cfg(test)]
+pub use kept_rate::TokenAnnotation;
+pub use kept_rate::compute_kept_rate;
+pub use patch_metrics::ClassificationMetrics;
+pub use patch_metrics::Counts;
+pub use patch_metrics::DeltaChrFMetrics;
+pub use patch_metrics::TokenChangeCounts;
+pub use patch_metrics::braces_disbalance;
+pub use patch_metrics::count_patch_token_changes;
+pub use patch_metrics::delta_chr_f;
+pub use patch_metrics::delta_chr_f_beta;
+pub use patch_metrics::exact_lines_match;
+pub use patch_metrics::extract_changed_lines_from_diff;
+pub use patch_metrics::has_isolated_whitespace_changes;
+pub use patch_metrics::is_editable_region_correct;
+pub use reversal::compute_prediction_reversal_ratio_from_history;
+pub use tree_sitter::count_tree_sitter_errors;
@@ -1,9 +1,10 @@
-use crate::metrics::tokenize;
+use crate::tokenize::tokenize;
+use serde::Serialize;
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
#[cfg(test)]
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum TokenAnnotation {
Context,
Kept,
@@ -11,7 +12,7 @@ pub enum TokenAnnotation {
}
#[allow(dead_code)]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Serialize)]
pub struct KeptRateResult {
/// Characters newly introduced by the candidate
pub candidate_new_chars: usize,
@@ -0,0 +1,1451 @@
+use std::collections::HashMap;
+
+use crate::tokenize::tokenize;
+use serde::Serialize;
+use similar::{DiffTag, TextDiff};
+
+/// Multiset of n-grams (or diff lines) keyed by their text.
+pub type Counts = HashMap<String, usize>;
+/// Signed per-item count difference between two `Counts` maps.
+type CountsDelta = HashMap<String, isize>;
+
+/// Context characters needed on each side of a change to capture all affected n-grams
+const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1;
+
+#[derive(Default, Debug, Clone, Serialize)]
+pub struct ClassificationMetrics {
+    pub true_positives: usize,
+    pub false_positives: usize,
+    pub false_negatives: usize,
+}
+
+impl ClassificationMetrics {
+    /// Builds TP/FP/FN counts by comparing two multisets.
+    ///
+    /// Matched occurrences count as true positives; surplus occurrences in
+    /// `actual` are false positives, missing ones are false negatives.
+    pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
+        let mut metrics = ClassificationMetrics::default();
+
+        for (ngram, &expected_count) in expected {
+            let actual_count = actual.get(ngram).copied().unwrap_or(0);
+            metrics.true_positives += expected_count.min(actual_count);
+            if actual_count > expected_count {
+                metrics.false_positives += actual_count - expected_count;
+            } else {
+                metrics.false_negatives += expected_count - actual_count;
+            }
+        }
+
+        // Items that appear only in `actual` are pure false positives.
+        for (ngram, &actual_count) in actual {
+            if !expected.contains_key(ngram) {
+                metrics.false_positives += actual_count;
+            }
+        }
+
+        metrics
+    }
+
+    /// Adds another instance's counts into this one.
+    pub fn accumulate(&mut self, other: &ClassificationMetrics) {
+        self.true_positives += other.true_positives;
+        self.false_positives += other.false_positives;
+        self.false_negatives += other.false_negatives;
+    }
+
+    /// Safe division: returns 0.0 for an empty denominator.
+    fn ratio(numerator: usize, denominator: usize) -> f64 {
+        if denominator == 0 {
+            0.0
+        } else {
+            numerator as f64 / denominator as f64
+        }
+    }
+
+    /// TP / (TP + FP); 0.0 when nothing was predicted.
+    pub fn precision(&self) -> f64 {
+        Self::ratio(self.true_positives, self.true_positives + self.false_positives)
+    }
+
+    /// TP / (TP + FN); 0.0 when nothing was expected.
+    pub fn recall(&self) -> f64 {
+        Self::ratio(self.true_positives, self.true_positives + self.false_negatives)
+    }
+
+    /// Harmonic mean of precision and recall; 0.0 when both are zero.
+    pub fn f1(&self) -> f64 {
+        let precision = self.precision();
+        let recall = self.recall();
+        if precision + recall == 0.0 {
+            0.0
+        } else {
+            2.0 * precision * recall / (precision + recall)
+        }
+    }
+}
+
+/// How whitespace is normalized before chrF n-gram counting.
+enum ChrfWhitespace {
+    /// Preserve whitespace as-is
+    #[allow(unused)]
+    Unchanged,
+
+    /// Ignore all whitespace differences
+    #[allow(unused)]
+    Ignore,
+
+    /// Collapse whitespace into single spaces
+    Collapse,
+}
+
+/// Maximum character n-gram order used by chrF (orders 1..=6).
+const CHR_F_CHAR_ORDER: usize = 6;
+/// F-score beta; < 1 weights precision over recall.
+const CHR_F_BETA: f64 = 0.5;
+/// Whitespace policy applied to every input before n-gram extraction.
+const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse;
+
+/// Exposes the beta constant used by `delta_chr_f`, for reporting.
+pub fn delta_chr_f_beta() -> f64 {
+    CHR_F_BETA
+}
+
+/// Aggregate result of `delta_chr_f`.
+#[derive(Default, Debug, Clone, Serialize)]
+pub struct DeltaChrFMetrics {
+    /// Final F-beta score scaled to 0.0..=100.0.
+    pub score: f64,
+    /// The beta used to compute `score` (see `CHR_F_BETA`).
+    pub beta: f64,
+    /// TP/FP/FN totals accumulated over all n-gram orders.
+    pub counts: ClassificationMetrics,
+    /// Precision averaged over all n-gram orders.
+    pub precision: f64,
+    /// Recall averaged over all n-gram orders.
+    pub recall: f64,
+}
+
+/// Computes delta-chrF metrics that compare two sets of edits.
+///
+/// This metric works by:
+/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
+/// 2. Comparing these deltas to measure how well actual edits match expected edits
+///
+/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
+/// the expected edits.
+pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
+    // Fast path: no edit expected and none made — perfect by definition.
+    if original == expected && expected == actual {
+        return DeltaChrFMetrics {
+            score: 100.0,
+            beta: CHR_F_BETA,
+            precision: 1.0,
+            recall: 1.0,
+            ..DeltaChrFMetrics::default()
+        };
+    }
+
+    let orig_chars: Vec<char> = filter_whitespace_chars(original);
+    let exp_chars: Vec<char> = filter_whitespace_chars(expected);
+    let act_chars: Vec<char> = filter_whitespace_chars(actual);
+
+    // Find the changed regions between original→expected and original→actual
+    // We only need to compute n-grams on these regions (plus context for boundary n-grams)
+    let (orig_for_exp, exp_region) = extract_changed_regions(&orig_chars, &exp_chars);
+    let (orig_for_act, act_region) = extract_changed_regions(&orig_chars, &act_chars);
+
+    let mut total_precision = 0.0;
+    let mut total_recall = 0.0;
+    let mut total_counts = ClassificationMetrics::default();
+
+    // Average precision/recall over every n-gram order, as in standard chrF,
+    // but applied to count *deltas* rather than raw counts.
+    for order in 1..=CHR_F_CHAR_ORDER {
+        let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order);
+        let exp_ngrams = count_ngrams_from_chars(&exp_region, order);
+        let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp);
+
+        let orig_ngrams_for_act = count_ngrams_from_chars(&orig_for_act, order);
+        let act_ngrams = count_ngrams_from_chars(&act_region, order);
+        let actual_delta = compute_ngram_delta(&act_ngrams, &orig_ngrams_for_act);
+
+        // No change on either side at this order counts as a perfect order.
+        if expected_delta.is_empty() && actual_delta.is_empty() {
+            total_precision += 1.0;
+            total_recall += 1.0;
+            continue;
+        }
+
+        let expected_counts = ngram_delta_to_counts(&expected_delta);
+        let actual_counts = ngram_delta_to_counts(&actual_delta);
+
+        let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
+        total_precision += counts.precision();
+        total_recall += counts.recall();
+        total_counts.accumulate(&counts);
+    }
+
+    let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
+    let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
+    // F-beta combination (beta = CHR_F_BETA), scaled to 0..=100.
+    let score = if average_precision + average_recall == 0.0 {
+        0.0
+    } else {
+        (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
+            / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
+            * 100.0
+    };
+
+    DeltaChrFMetrics {
+        score,
+        beta: CHR_F_BETA,
+        counts: total_counts,
+        precision: average_precision,
+        recall: average_recall,
+    }
+}
+
+/// Reference implementation of delta-chrF metrics (original, non-optimized version).
+/// Used for testing that the optimized version produces identical results.
+#[cfg(test)]
+fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
+    if original == expected && expected == actual {
+        return DeltaChrFMetrics {
+            score: 100.0,
+            beta: CHR_F_BETA,
+            precision: 1.0,
+            recall: 1.0,
+            ..DeltaChrFMetrics::default()
+        };
+    }
+
+    // Unlike the optimized path, n-grams are counted over the *entire* texts.
+    let original_ngrams = chr_f_ngram_counts(original);
+    let expected_ngrams = chr_f_ngram_counts(expected);
+    let actual_ngrams = chr_f_ngram_counts(actual);
+
+    let mut total_precision = 0.0;
+    let mut total_recall = 0.0;
+    let mut total_counts = ClassificationMetrics::default();
+
+    // `order` here is a 0-based index into the precomputed per-order maps,
+    // equivalent to the 1..=CHR_F_CHAR_ORDER loop in `delta_chr_f`.
+    for order in 0..CHR_F_CHAR_ORDER {
+        let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
+        let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
+
+        if expected_delta.is_empty() && actual_delta.is_empty() {
+            total_precision += 1.0;
+            total_recall += 1.0;
+            continue;
+        }
+
+        let expected_counts = ngram_delta_to_counts(&expected_delta);
+        let actual_counts = ngram_delta_to_counts(&actual_delta);
+
+        let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
+        total_precision += counts.precision();
+        total_recall += counts.recall();
+        total_counts.accumulate(&counts);
+    }
+
+    let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
+    let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
+    let score = if average_precision + average_recall == 0.0 {
+        0.0
+    } else {
+        (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
+            / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
+            * 100.0
+    };
+
+    DeltaChrFMetrics {
+        score,
+        beta: CHR_F_BETA,
+        counts: total_counts,
+        precision: average_precision,
+        recall: average_recall,
+    }
+}
+
+/// Applies the module-wide whitespace policy and returns the text as chars.
+fn filter_whitespace_chars(text: &str) -> Vec<char> {
+    match CHR_F_WHITESPACE {
+        ChrfWhitespace::Collapse => collapse_whitespace(text.chars()),
+        ChrfWhitespace::Ignore => text.chars().filter(|c| !c.is_whitespace()).collect(),
+        ChrfWhitespace::Unchanged => text.chars().collect(),
+    }
+}
+
+/// Collapse whitespace into single spaces.
+/// Newlines and other whitespace are collapsed independently, so a run of
+/// blanks that mixes both can still emit one space and one newline.
+fn collapse_whitespace(chars: impl Iterator<Item = char>) -> Vec<char> {
+    let mut collapsed = Vec::new();
+    // The collapsed whitespace char (' ' or '\n') emitted most recently,
+    // reset whenever a non-whitespace char is seen.
+    let mut pending: Option<char> = None;
+    for ch in chars {
+        let normalized = match ch {
+            '\n' => '\n',
+            c if c.is_whitespace() => ' ',
+            c => {
+                collapsed.push(c);
+                pending = None;
+                continue;
+            }
+        };
+        if pending != Some(normalized) {
+            collapsed.push(normalized);
+            pending = Some(normalized);
+        }
+    }
+    collapsed
+}
+
+/// Extract only the changed regions between two texts, with context for n-gram boundaries.
+///
+/// Returns (original_affected_region, modified_affected_region) as Vec<char>.
+///
+/// The key insight: when computing n-gram delta between two nearly-identical texts,
+/// n-grams from unchanged regions cancel out. We only need to process:
+/// 1. The changed content itself
+/// 2. CONTEXT_CHARS (n-1) characters before and after, to capture boundary-crossing n-grams
+fn extract_changed_regions(original: &[char], modified: &[char]) -> (Vec<char>, Vec<char>) {
+    // Find longest common prefix
+    let prefix_len = original
+        .iter()
+        .zip(modified.iter())
+        .take_while(|(a, b)| a == b)
+        .count();
+
+    // Find longest common suffix (that doesn't overlap with prefix)
+    let orig_remaining = original.len().saturating_sub(prefix_len);
+    let mod_remaining = modified.len().saturating_sub(prefix_len);
+    let max_suffix = orig_remaining.min(mod_remaining);
+
+    // `take(max_suffix)` caps the scan so the suffix never overlaps the prefix.
+    let suffix_len = original
+        .iter()
+        .rev()
+        .zip(modified.iter().rev())
+        .take(max_suffix)
+        .take_while(|(a, b)| a == b)
+        .count();
+
+    // Calculate the changed region boundaries
+    let orig_change_start = prefix_len;
+    let orig_change_end = original.len().saturating_sub(suffix_len);
+    let mod_change_start = prefix_len;
+    let mod_change_end = modified.len().saturating_sub(suffix_len);
+
+    // If there's no actual change, return empty regions
+    if orig_change_start >= orig_change_end && mod_change_start >= mod_change_end {
+        return (Vec::new(), Vec::new());
+    }
+
+    // Expand to include context for n-gram boundaries
+    let orig_context_start = orig_change_start.saturating_sub(CONTEXT_CHARS);
+    let orig_context_end = (orig_change_end + CONTEXT_CHARS).min(original.len());
+    let mod_context_start = mod_change_start.saturating_sub(CONTEXT_CHARS);
+    let mod_context_end = (mod_change_end + CONTEXT_CHARS).min(modified.len());
+
+    let orig_region: Vec<char> = original[orig_context_start..orig_context_end].to_vec();
+    let mod_region: Vec<char> = modified[mod_context_start..mod_context_end].to_vec();
+
+    (orig_region, mod_region)
+}
+
+/// Count n-grams directly from a char slice (avoids String allocation for the full text)
+fn count_ngrams_from_chars(chars: &[char], n: usize) -> Counts {
+    let mut counts = Counts::default();
+    if chars.len() >= n {
+        for window in chars.windows(n) {
+            let ngram: String = window.iter().collect();
+            *counts.entry(ngram).or_default() += 1;
+        }
+    }
+    counts
+}
+
+#[allow(dead_code)]
+/// Per-order n-gram counts (index 0 = unigrams) over the whole text,
+/// after applying the module-wide whitespace policy.
+///
+/// Delegates to `filter_whitespace_chars` and `count_ngrams_from_chars`
+/// instead of duplicating the whitespace-normalization match, so the
+/// reference and optimized paths cannot drift apart.
+fn chr_f_ngram_counts(text: &str) -> Vec<Counts> {
+    let chars = filter_whitespace_chars(text);
+
+    (1..=CHR_F_CHAR_ORDER)
+        .map(|order| count_ngrams_from_chars(&chars, order))
+        .collect()
+}
+
+/// Signed per-n-gram count difference: `after` minus `before`.
+fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta {
+    let mut delta = CountsDelta::default();
+
+    for (ngram, &before_count) in before {
+        let after_count = after.get(ngram).copied().unwrap_or(0);
+        delta.insert(ngram.clone(), after_count as isize - before_count as isize);
+    }
+
+    // N-grams that only `after` contains enter with their full positive count;
+    // keys already handled above are left untouched by `or_insert`.
+    for (ngram, &after_count) in after {
+        delta.entry(ngram.clone()).or_insert(after_count as isize);
+    }
+
+    delta
+}
+
+/// Convert negative counts to special deletion tokens.
+/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1},
+/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo"
+/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive.
+fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
+    delta
+        .iter()
+        .filter_map(|(ngram, &count_delta)| match count_delta {
+            0 => None,
+            d if d > 0 => Some((ngram.clone(), d as usize)),
+            d => Some((format!("¬{ngram}"), d.unsigned_abs())),
+        })
+        .collect()
+}
+
+#[allow(dead_code)]
+/// Counts character n-grams of order `n` in `text`.
+///
+/// Delegates to `count_ngrams_from_chars` so there is a single n-gram
+/// counting implementation (previously this loop was duplicated here).
+fn count_ngrams(text: &str, n: usize) -> Counts {
+    let chars: Vec<char> = text.chars().collect();
+    count_ngrams_from_chars(&chars, n)
+}
+
+/// Sums the absolute imbalance of `{}`, `()`, and `[]` pairs in `text`.
+///
+/// For each bracket kind the contribution is |opens - closes|; the three
+/// contributions are added. Counts all kinds in a single pass over the
+/// input instead of scanning the string once per bracket character.
+pub fn braces_disbalance(text: &str) -> usize {
+    let (mut curly, mut round, mut square) = (0isize, 0isize, 0isize);
+    for c in text.chars() {
+        match c {
+            '{' => curly += 1,
+            '}' => curly -= 1,
+            '(' => round += 1,
+            ')' => round -= 1,
+            '[' => square += 1,
+            ']' => square -= 1,
+            _ => {}
+        }
+    }
+    (curly.abs() + round.abs() + square.abs()) as usize
+}
+
+/// Extracts changed lines from a unified diff string.
+/// Returns a bag (multiset) of lines that were added (+) or removed (-).
+/// The +/- prefix is included in the line to distinguish additions from deletions.
+pub fn extract_changed_lines_from_diff(diff: &str) -> Counts {
+    // File headers (---/+++), hunk headers (@@), and git metadata lines are
+    // not content changes and are skipped.
+    const SKIP_PREFIXES: [&str; 5] = ["---", "+++", "@@", "diff ", "index "];
+
+    let mut counts = Counts::default();
+
+    for line in diff.lines() {
+        if SKIP_PREFIXES.iter().any(|prefix| line.starts_with(prefix)) {
+            continue;
+        }
+        if line.starts_with('+') || line.starts_with('-') {
+            *counts.entry(line.to_string()).or_default() += 1;
+        }
+    }
+
+    counts
+}
+
+/// Computes exact lines match metrics between expected and actual patches.
+/// Treats changed lines as a bag (multiset) - order is discarded but count matters.
+/// Returns ClassificationMetrics with TP/FP/FN counts.
+pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics {
+    ClassificationMetrics::from_counts(
+        &extract_changed_lines_from_diff(expected_patch),
+        &extract_changed_lines_from_diff(actual_patch),
+    )
+}
+
+/// Returns whether the patch contains any isolated whitespace-only changes.
+///
+/// A whitespace-only change is an added or deleted line whose content is empty or
+/// contains only whitespace. It is "isolated" when it is not adjacent to any
+/// substantive (non-whitespace) change within the same contiguous change group.
+pub fn has_isolated_whitespace_changes(patch_str: &str, cursor_row: Option<u32>) -> bool {
+    let patch = Patch::parse_unified_diff(patch_str);
+
+    // Convert the 0-based cursor row to the 1-based line numbering used by
+    // unified-diff hunk headers.
+    let cursor_new_file_line = cursor_row.map(|row| (row + 1) as usize);
+
+    for hunk in &patch.hunks {
+        let lines = &hunk.lines;
+        // 1-based line number in the *new* file; additions and context lines
+        // advance it, deletions do not.
+        let mut new_text_line = hunk.new_start as usize;
+
+        for (i, line) in lines.iter().enumerate() {
+            let content = match line {
+                PatchLine::Addition(s) => {
+                    let addition_line = new_text_line;
+                    new_text_line += 1;
+                    // A blank line added exactly at the cursor is skipped,
+                    // so it is never flagged as an isolated whitespace change.
+                    if s.trim().is_empty() && cursor_new_file_line == Some(addition_line) {
+                        continue;
+                    }
+                    s.as_str()
+                }
+                PatchLine::Deletion(s) => s.as_str(),
+                PatchLine::Context(_) => {
+                    new_text_line += 1;
+                    continue;
+                }
+                // Garbage lines are ignored for this analysis.
+                _ => continue,
+            };
+
+            // Only whitespace-only changes are candidates.
+            if !content.trim().is_empty() {
+                continue;
+            }
+
+            if is_whitespace_change_isolated(lines, i) {
+                return true;
+            }
+        }
+    }
+
+    false
+}
+
+/// True when every +/- neighbor of `lines[index]`, scanning outward until the
+/// nearest non-change line on each side, is itself whitespace-only.
+fn is_whitespace_change_isolated(lines: &[PatchLine], index: usize) -> bool {
+    // Walks a run of adjacent +/- lines; stops at the first context/garbage
+    // line, and fails as soon as a substantive change is found.
+    fn run_is_blank<'a>(neighbors: impl Iterator<Item = &'a PatchLine>) -> bool {
+        for line in neighbors {
+            match line {
+                PatchLine::Addition(s) | PatchLine::Deletion(s) => {
+                    if !s.trim().is_empty() {
+                        return false;
+                    }
+                }
+                _ => return true,
+            }
+        }
+        true
+    }
+
+    run_is_blank(lines[..index].iter().rev()) && run_is_blank(lines[index + 1..].iter())
+}
+
+/// A simple proxy for whether the prediction respects editable region.
+///
+/// Heuristic: a run of three or more deletions at the very start of the first
+/// hunk is a typical sign that the editable region was chosen incorrectly.
+pub fn is_editable_region_correct(actual_patch: &str) -> bool {
+    let patch = Patch::parse_unified_diff(actual_patch);
+
+    match patch.hunks.first() {
+        None => true,
+        Some(hunk) => {
+            let leading_deletions = hunk
+                .lines
+                .iter()
+                .take_while(|line| matches!(line, PatchLine::Deletion(_)))
+                .count();
+            leading_deletions < 3
+        }
+    }
+}
+
+/// Totals of tokens added and removed by a patch; see `count_patch_token_changes`.
+#[derive(Debug, Default, Clone, Serialize)]
+pub struct TokenChangeCounts {
+    /// Tokens present only on the new (`+`) side of the diff.
+    pub inserted_tokens: usize,
+    /// Tokens present only on the old (`-`) side of the diff.
+    pub deleted_tokens: usize,
+}
+
+/// Counts the number of inserted and deleted tokens in a unified diff patch.
+///
+/// Tokens are words and whitespace sequences (as defined by `word_diff::tokenize`).
+/// Within each hunk, the old (`-`) and new (`+`) lines are compared at the token level
+/// using an LCS-based diff, so modified lines only count the actually changed tokens
+/// rather than the entire line.
+pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts {
+    let mut counts = TokenChangeCounts::default();
+    // One contiguous run of -/+ lines; flushed at every context/header line so
+    // only adjacent changes are diffed against each other.
+    let mut old_lines: Vec<&str> = Vec::new();
+    let mut new_lines: Vec<&str> = Vec::new();
+
+    let flush =
+        |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, counts: &mut TokenChangeCounts| {
+            if old_lines.is_empty() && new_lines.is_empty() {
+                return;
+            }
+
+            // Strip the leading '-'/'+' marker; it is 1 ASCII byte, so the
+            // byte slice `&line[1..]` cannot split a UTF-8 character.
+            let old_text: String = old_lines
+                .iter()
+                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            let new_text: String = new_lines
+                .iter()
+                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            let old_tokens = tokenize(&old_text);
+            let new_tokens = tokenize(&new_text);
+            let ops = diff_tokens(&old_tokens, &new_tokens);
+
+            // Tally only the changed token spans; Equal spans cost nothing.
+            for op in ops {
+                match op {
+                    DiffOp::Equal(..) => {}
+                    DiffOp::Delete(start, end) => {
+                        counts.deleted_tokens += end - start;
+                    }
+                    DiffOp::Insert(start, end) => {
+                        counts.inserted_tokens += end - start;
+                    }
+                    DiffOp::Replace {
+                        old_start,
+                        old_end,
+                        new_start,
+                        new_end,
+                    } => {
+                        counts.deleted_tokens += old_end - old_start;
+                        counts.inserted_tokens += new_end - new_start;
+                    }
+                }
+            }
+
+            old_lines.clear();
+            new_lines.clear();
+        };
+
+    for line in patch.lines() {
+        // Header and metadata lines terminate the current change group.
+        if line.starts_with("---")
+            || line.starts_with("+++")
+            || line.starts_with("@@")
+            || line.starts_with("diff ")
+            || line.starts_with("index ")
+        {
+            flush(&mut old_lines, &mut new_lines, &mut counts);
+        } else if line.starts_with('-') {
+            old_lines.push(line);
+        } else if line.starts_with('+') {
+            new_lines.push(line);
+        } else {
+            // Context (or any other) line also ends the current group.
+            flush(&mut old_lines, &mut new_lines, &mut counts);
+        }
+    }
+
+    flush(&mut old_lines, &mut new_lines, &mut counts);
+    counts
+}
+
+/// A single edit operation over token indices.
+///
+/// `Equal`/`Delete` ranges index into the old token slice, `Insert` ranges
+/// into the new one; `Replace` carries both ranges.
+#[allow(dead_code)]
+#[derive(Debug)]
+enum DiffOp {
+    Equal(usize, usize),
+    Delete(usize, usize),
+    Insert(usize, usize),
+    Replace {
+        old_start: usize,
+        old_end: usize,
+        new_start: usize,
+        new_end: usize,
+    },
+}
+
+/// Runs a token-level LCS diff and converts the result into `DiffOp`s.
+fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
+    let mut ops = Vec::new();
+    for op in TextDiff::from_slices(old, new).ops() {
+        let old_range = op.old_range();
+        let new_range = op.new_range();
+        let mapped = match op.tag() {
+            DiffTag::Equal => DiffOp::Equal(old_range.start, old_range.end),
+            DiffTag::Delete => DiffOp::Delete(old_range.start, old_range.end),
+            DiffTag::Insert => DiffOp::Insert(new_range.start, new_range.end),
+            DiffTag::Replace => DiffOp::Replace {
+                old_start: old_range.start,
+                old_end: old_range.end,
+                new_start: new_range.start,
+                new_end: new_range.end,
+            },
+        };
+        ops.push(mapped);
+    }
+    ops
+}
+
+#[derive(Debug, Default, Clone)]
+struct Patch {
+ hunks: Vec<Hunk>,
+}
+
+impl Patch {
+ fn parse_unified_diff(unified_diff: &str) -> Patch {
+ let mut current_file = String::new();
+ let mut is_filename_inherited = false;
+ let mut hunk = Hunk::default();
+ let mut patch = Patch::default();
+ let mut in_header = true;
+
+ for line in unified_diff.lines() {
+ if line.starts_with("--- ") || line.starts_with("+++") || line.starts_with("@@") {
+ in_header = false;
+ }
+
+ if in_header {
+ continue;
+ }
+
+ if line.starts_with("@@") {
+ if !hunk.lines.is_empty() {
+ patch.hunks.push(hunk);
+ }
+ hunk = Hunk::from_header(line, ¤t_file, is_filename_inherited);
+ is_filename_inherited = true;
+ } else if let Some(path) = line.strip_prefix("--- ") {
+ is_filename_inherited = false;
+ let path = path.trim().strip_prefix("a/").unwrap_or(path);
+ if path != "/dev/null" {
+ current_file = path.into();
+ }
+ } else if let Some(path) = line.strip_prefix("+++ ") {
+ is_filename_inherited = false;
+ let path = path.trim().strip_prefix("b/").unwrap_or(path);
+ if path != "/dev/null" {
+ current_file = path.into();
+ }
+ } else if let Some(line) = line.strip_prefix('+') {
+ hunk.lines.push(PatchLine::Addition(line.to_string()));
+ } else if let Some(line) = line.strip_prefix('-') {
+ hunk.lines.push(PatchLine::Deletion(line.to_string()));
+ } else if let Some(line) = line.strip_prefix(' ') {
+ hunk.lines.push(PatchLine::Context(line.to_string()));
+ } else {
+ hunk.lines.push(PatchLine::Garbage(line.to_string()));
+ }
+ }
+
+ if !hunk.lines.is_empty() {
+ patch.hunks.push(hunk);
+ }
+
+ patch
+ }
+}
+
+/// One hunk of a unified diff.
+#[derive(Debug, Default, Clone)]
+struct Hunk {
+    /// 1-based start line of this hunk in the new file (from the `@@` header).
+    new_start: isize,
+    /// The hunk body, one entry per diff line.
+    lines: Vec<PatchLine>,
+}
+
+impl Hunk {
+    /// Builds an empty hunk from an `@@ -a,b +c,d @@` header line.
+    /// The filename arguments are currently unused; kept for call-site symmetry.
+    fn from_header(header: &str, _filename: &str, _is_filename_inherited: bool) -> Self {
+        let (_, _, new_start, _, _) = Self::parse_hunk_header(header);
+        Self {
+            new_start,
+            lines: Vec::new(),
+        }
+    }
+
+    /// Parses `@@ -old_start,old_count +new_start,new_count @@ comment`.
+    /// Returns (old_start, old_count, new_start, new_count, comment);
+    /// a malformed header yields all zeros and an empty comment.
+    fn parse_hunk_header(line: &str) -> (isize, isize, isize, isize, String) {
+        let header_part = line.trim_start_matches("@@").trim();
+        let parts: Vec<&str> = header_part.split_whitespace().collect();
+
+        if parts.len() < 2 {
+            return (0, 0, 0, 0, String::new());
+        }
+
+        let old_part = parts[0].trim_start_matches('-');
+        let new_part = parts[1].trim_start_matches('+');
+
+        let (old_start, old_count) = Hunk::parse_hunk_header_range(old_part);
+        let (new_start, new_count) = Hunk::parse_hunk_header_range(new_part);
+
+        // Everything after the ranges (minus the closing "@@") is the comment.
+        let comment = if parts.len() > 2 {
+            parts[2..]
+                .join(" ")
+                .trim_start_matches("@@")
+                .trim()
+                .to_string()
+        } else {
+            String::new()
+        };
+
+        (
+            old_start as isize,
+            old_count as isize,
+            new_start as isize,
+            new_count as isize,
+            comment,
+        )
+    }
+
+    /// Parses a `start[,count]` range; a missing count defaults to 1.
+    fn parse_hunk_header_range(part: &str) -> (usize, usize) {
+        if let Some((start, count)) = part.split_once(',') {
+            (start.parse().unwrap_or(0), count.parse().unwrap_or(0))
+        } else {
+            (part.parse().unwrap_or(0), 1)
+        }
+    }
+}
+
+/// One line of a hunk body, classified by its leading character.
+#[derive(Clone, Debug, Eq, PartialEq)]
+enum PatchLine {
+    /// Unchanged line (leading space).
+    Context(String),
+    /// Added line (leading '+').
+    Addition(String),
+    /// Removed line (leading '-').
+    Deletion(String),
+    /// Line with no recognized prefix; preserved for robustness.
+    Garbage(String),
+}
+
+// Tests for the changed-region optimization: `delta_chr_f` must agree exactly
+// with the brute-force `delta_chr_f_reference` implementation.
+#[cfg(test)]
+mod test_optimization {
+    use super::*;
+
+    #[test]
+    fn test_extract_changed_regions_simple() {
+        let original: Vec<char> = "hello world".chars().collect();
+        let modified: Vec<char> = "hello there".chars().collect();
+
+        let (orig_region, mod_region) = extract_changed_regions(&original, &modified);
+
+        // "world" vs "there" - with 5 chars context, we get "ello world" vs "ello there"
+        // (or less if not enough chars available)
+        assert!(orig_region.len() < original.len());
+        assert!(mod_region.len() < modified.len());
+    }
+
+    #[test]
+    fn test_extract_changed_regions_insertion() {
+        let original: Vec<char> = "abcdef".chars().collect();
+        let modified: Vec<char> = "abcXYZdef".chars().collect();
+
+        let (orig_region, mod_region) = extract_changed_regions(&original, &modified);
+
+        // The insertion is between c and d, so we need context around that point
+        assert!(orig_region.len() <= original.len());
+        assert!(mod_region.iter().collect::<String>().contains("XYZ"));
+    }
+
+    #[test]
+    fn test_extract_changed_regions_identical() {
+        let text: Vec<char> = "identical text".chars().collect();
+
+        let (orig_region, mod_region) = extract_changed_regions(&text, &text);
+
+        // When texts are identical, regions should be empty
+        assert!(orig_region.is_empty());
+        assert!(mod_region.is_empty());
+    }
+
+    #[test]
+    fn test_optimized_matches_original_score() {
+        // Test that our optimized version produces the same results
+        let test_cases = vec![
+            ("hello world", "hello there", "hello world"),
+            (
+                "fn main() {}",
+                "fn main() { println!(); }",
+                "fn main() { print!(); }",
+            ),
+            ("abcdefghij", "abcXXXghij", "abcYYghij"),
+            ("unchanged", "unchanged", "unchanged"),
+            (
+                "prefix middle suffix",
+                "prefix CHANGED suffix",
+                "prefix middle suffix",
+            ),
+        ];
+
+        for (original, expected, actual) in test_cases {
+            let score = delta_chr_f(original, expected, actual).score;
+            // Just verify it produces a reasonable score (0-100)
+            assert!(
+                score >= 0.0 && score <= 100.0,
+                "Score {} out of range for ({}, {}, {})",
+                score,
+                original,
+                expected,
+                actual
+            );
+        }
+    }
+
+    #[test]
+    fn test_optimized_equals_reference() {
+        // Comprehensive test that optimized version matches reference implementation exactly
+        let test_cases = vec![
+            // Basic cases
+            ("hello world", "hello there", "hello world"),
+            ("hello world", "hello there", "hello there"),
+            ("unchanged", "unchanged", "unchanged"),
+            // Code-like cases
+            (
+                "fn main() { println!(\"Hello\"); }",
+                "fn main() { println!(\"Hello, World!\"); }",
+                "fn main() { println!(\"Hello, World!\"); }",
+            ),
+            (
+                "fn main() { println!(\"Hello\"); }",
+                "fn main() { println!(\"Hello, World!\"); }",
+                "fn main() { println!(\"Goodbye\"); }",
+            ),
+            // Insertion
+            ("abcdef", "abcXYZdef", "abcdef"),
+            ("abcdef", "abcXYZdef", "abcXYZdef"),
+            ("abcdef", "abcXYZdef", "abcABCdef"),
+            // Deletion
+            ("abcXYZdef", "abcdef", "abcXYZdef"),
+            ("abcXYZdef", "abcdef", "abcdef"),
+            // Multiple changes (simulated by different expected/actual)
+            ("one two three four", "one THREE four", "one two FOUR"),
+            // Edge cases
+            ("a", "b", "c"),
+            ("", "abc", ""),
+            ("abc", "", "abc"),
+            // Longer text with small change
+            (
+                "This is a longer piece of text that contains many words and characters to process",
+                "This is a longer piece of TEXT that contains many words and characters to process",
+                "This is a longer piece of text that contains many words and characters to process",
+            ),
+            // Change at the beginning
+            (
+                "ORIGINAL start of text",
+                "NEW start of text",
+                "DIFFERENT start of text",
+            ),
+            // Change at the end
+            (
+                "text ending ORIGINAL",
+                "text ending NEW",
+                "text ending DIFFERENT",
+            ),
+            // Whitespace (should be ignored)
+            ("hello world", "hello there", "hello world"),
+            ("a b c d", "a X c d", "a Y c d"),
+        ];
+
+        for (original, expected, actual) in test_cases {
+            let optimized_metrics = delta_chr_f(original, expected, actual);
+            let reference_metrics = delta_chr_f_reference(original, expected, actual);
+
+            // Scores are floats; compare with a tight tolerance.
+            assert!(
+                (optimized_metrics.score - reference_metrics.score).abs() < 1e-10,
+                "Score mismatch for ({:?}, {:?}, {:?}):\n  optimized: {}\n  reference: {}",
+                original,
+                expected,
+                actual,
+                optimized_metrics.score,
+                reference_metrics.score
+            );
+            // Raw counts must match exactly, not just the derived score.
+            assert_eq!(
+                optimized_metrics.counts.true_positives,
+                reference_metrics.counts.true_positives
+            );
+            assert_eq!(
+                optimized_metrics.counts.false_positives,
+                reference_metrics.counts.false_positives
+            );
+            assert_eq!(
+                optimized_metrics.counts.false_negatives,
+                reference_metrics.counts.false_negatives
+            );
+            assert!((optimized_metrics.precision - reference_metrics.precision).abs() < 1e-10);
+            assert!((optimized_metrics.recall - reference_metrics.recall).abs() < 1e-10);
+        }
+    }
+
+    #[test]
+    fn test_delta_chr_f_metrics_include_counts_and_rates() {
+        let original = "one two three";
+        let expected = "one three";
+        let actual = "one two four";
+
+        let metrics = delta_chr_f(original, expected, actual);
+
+        assert!(metrics.score > 20.0 && metrics.score < 40.0);
+        assert!(metrics.counts.true_positives > 0);
+        assert!(metrics.counts.false_positives > 0);
+        assert!(metrics.counts.false_negatives > 0);
+        assert!(metrics.precision > 0.0 && metrics.precision < 1.0);
+        assert!(metrics.recall > 0.0 && metrics.recall < 1.0);
+        assert_eq!(metrics.beta, CHR_F_BETA);
+    }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use indoc::indoc;
+
+ fn cursor_on_line(one_based_line: u32) -> u32 {
+ one_based_line - 1
+ }
+
+ #[test]
+ fn test_delta_chr_f_perfect_match() {
+ let original = "fn main() { println!(\"Hello\");}";
+ let expected = "fn main() { println!(\"Hello, World!\");}";
+
+ let score = delta_chr_f(original, expected, expected).score;
+ assert!((score - 100.0).abs() < 1e-2);
+ }
+
+ #[test]
+ fn test_delta_chr_f_wrong_edit() {
+ // When the edit is wrong
+ let original = "one two three";
+ let expected = "one three"; // deleted "two "
+ let actual = "one two four"; // deleted "three", added "four"
+
+ // Then the score should be low
+ let score = delta_chr_f(original, expected, actual).score;
+ assert!(score > 20.0 && score < 40.0);
+ }
+
+ #[test]
+ fn test_delta_chr_f_partial_match() {
+ let original = "let x = 42;";
+ let expected = "let x = 100;";
+ let actual = "let x = 99;";
+
+ // We got the edit location right, but the replacement text is wrong.
+ // Deleted ngrams will match, bringing the score somewhere in the middle.
+ let score = delta_chr_f(original, expected, actual).score;
+ assert!(score > 40.0 && score < 60.0);
+ }
+
+ #[test]
+ fn test_delta_chr_f_missed_edit() {
+ // When predictions makes no changes
+ let original = "prefix old suffix";
+ let expected = "prefix new suffix";
+ let actual = "prefix old suffix"; // no change
+
+ // Then the score should be low (all expected changes are false negatives)
+ let score = delta_chr_f(original, expected, actual).score;
+ assert!(score < 20.0);
+ }
+
+ #[test]
+ fn test_delta_chr_f_extra_edit() {
+ // When adding unexpected content
+ let original = "helloworld";
+ let expected = "helloworld"; // no change expected
+ let actual = "helloextraworld"; // added "extra"
+
+ // Then the score should be low (all actual changes are false positives)
+ let score = delta_chr_f(original, expected, actual).score;
+ assert!(score < 20.0);
+ }
+
+ #[test]
+ fn test_delta_chr_f_no_changes() {
+ let text = "unchanged text";
+ let score = delta_chr_f(text, text, text).score;
+ assert!((score - 100.0).abs() < 1e-2);
+ }
+
+ #[test]
+ fn test_braces_disbalance() {
+ let text = "let x = { 1 + 2 };";
+ assert_eq!(braces_disbalance(text), 0);
+
+ let text = "let x = { 1 + 2";
+ assert_eq!(braces_disbalance(text), 1);
+
+ let text = "let x = { 1 + 2 )";
+ assert_eq!(braces_disbalance(text), 2);
+ }
+
+ #[test]
+ fn test_extract_changed_lines_from_diff() {
+ let diff = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+ fn main() {
+- println!("hello");
++ println!("world");
+ }"#;
+
+ let counts = extract_changed_lines_from_diff(diff);
+ assert_eq!(counts.get("- println!(\"hello\");"), Some(&1));
+ assert_eq!(counts.get("+ println!(\"world\");"), Some(&1));
+ assert_eq!(counts.len(), 2);
+ }
+
+ #[test]
+ fn test_extract_changed_lines_skips_headers() {
+ let diff = r#"diff --git a/file.rs b/file.rs
+index abc123..def456 100644
+--- a/file.rs
++++ b/file.rs
+@@ -1,2 +1,2 @@
+-old line
++new line"#;
+
+ let counts = extract_changed_lines_from_diff(diff);
+ assert_eq!(counts.get("-old line"), Some(&1));
+ assert_eq!(counts.get("+new line"), Some(&1));
+ assert_eq!(counts.len(), 2);
+ }
+
+ #[test]
+ fn test_exact_lines_match_perfect() {
+ let expected = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+ let actual = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ assert_eq!(metrics.true_positives, 4);
+ assert_eq!(metrics.false_positives, 0);
+ assert_eq!(metrics.false_negatives, 0);
+ assert!((metrics.precision() - 1.0).abs() < 1e-6);
+ assert!((metrics.recall() - 1.0).abs() < 1e-6);
+ assert!((metrics.f1() - 1.0).abs() < 1e-6);
+ }
+
+ #[test]
+ fn test_exact_lines_match_partial() {
+ let expected = r#"-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+ let actual = r#"-old line 1
++new line 1
++extra line"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ // TP: "-old line 1" and "+new line 1" (2)
+ // FP: "+extra line" (1)
+ // FN: "-old line 2" and "+new line 2" (2)
+ assert_eq!(metrics.true_positives, 2);
+ assert_eq!(metrics.false_positives, 1);
+ assert_eq!(metrics.false_negatives, 2);
+ }
+
+ #[test]
+ fn test_exact_lines_match_no_overlap() {
+ let expected = r#"-line a
++line b"#;
+
+ let actual = r#"-line x
++line y"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ assert_eq!(metrics.true_positives, 0);
+ assert_eq!(metrics.false_positives, 2);
+ assert_eq!(metrics.false_negatives, 2);
+ assert!((metrics.precision()).abs() < 1e-6);
+ assert!((metrics.recall()).abs() < 1e-6);
+ }
+
+ #[test]
+ fn test_exact_lines_match_duplicate_lines() {
+ let expected = r#"+line a
++line a
++line a"#;
+
+ let actual = r#"+line a
++line a"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ // Expected has 3 "+line a", actual has 2
+ // TP: 2, FN: 1, FP: 0
+ assert_eq!(metrics.true_positives, 2);
+ assert_eq!(metrics.false_positives, 0);
+ assert_eq!(metrics.false_negatives, 1);
+ }
+
+ #[test]
+ fn test_exact_lines_match_empty_patches() {
+ let metrics = exact_lines_match("", "");
+ assert_eq!(metrics.true_positives, 0);
+ assert_eq!(metrics.false_positives, 0);
+ assert_eq!(metrics.false_negatives, 0);
+ }
+
+ #[test]
+ fn test_is_editable_region_correct() {
+ let patch = indoc! {"
+ @@ -1,1 +1,1 @@
+ -context
+ -removed
+ -from the beginning of the file
+ import sys
+ +sys.exit(0)
+
+ "};
+ assert!(!is_editable_region_correct(patch));
+
+ let patch = indoc! {"
+ @@ -1,1 +1,1 @@
+ "};
+ assert!(is_editable_region_correct(patch));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_purely_whitespace_patch() {
+ let patch = indoc! {"
+ @@ -1,3 +1,4 @@
+ fn main() {
+ +
+ println!(\"hello\");
+ }
+ "};
+ assert!(has_isolated_whitespace_changes(patch, None));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_adjacent_to_real_change() {
+ let patch = indoc! {"
+ @@ -1,3 +1,4 @@
+ fn main() {
+ +
+ + let x = 1;
+ println!(\"hello\");
+ }
+ "};
+ assert!(!has_isolated_whitespace_changes(patch, None));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_no_whitespace_changes() {
+ let patch = indoc! {"
+ @@ -1,3 +1,3 @@
+ fn main() {
+ - println!(\"hello\");
+ + println!(\"world\");
+ }
+ "};
+ assert!(!has_isolated_whitespace_changes(patch, None));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_deletion() {
+ let patch = indoc! {"
+ @@ -1,4 +1,3 @@
+ fn main() {
+ -
+ println!(\"hello\");
+ }
+ "};
+ assert!(has_isolated_whitespace_changes(patch, None));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_mixed_groups() {
+ let patch = indoc! {"
+ @@ -1,7 +1,8 @@
+ fn main() {
+ +
+ let x = 1;
+ - let y = 2;
+ + let y = 3;
+
+ +
+ println!(\"hello\");
+ }
+ "};
+ assert!(has_isolated_whitespace_changes(patch, None));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_empty_patch() {
+ let patch = "";
+ assert!(!has_isolated_whitespace_changes(patch, None));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_skipped_on_cursor_line() {
+ // The addition of a blank line at new-file line 2 should be skipped
+ // because the cursor is on that line.
+ let patch = indoc! {"
+ @@ -1,3 +1,4 @@
+ fn main() {
+ +
+ println!(\"hello\");
+ }
+ "};
+ // New-file line 2 is the added blank line
+ let cursor = cursor_on_line(2);
+ assert!(!has_isolated_whitespace_changes(patch, Some(cursor)));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_not_skipped_when_cursor_on_different_line() {
+ // The blank line is at new-file line 2, but the cursor is on line 1.
+ let patch = indoc! {"
+ @@ -1,3 +1,4 @@
+ fn main() {
+ +
+ println!(\"hello\");
+ }
+ "};
+ let cursor = cursor_on_line(1);
+ assert!(has_isolated_whitespace_changes(patch, Some(cursor)));
+ }
+
+ #[test]
+ fn test_isolated_whitespace_deletion_not_skipped_by_cursor() {
+ // Deletions don't have a new-file line, so cursor can't suppress them.
+ let patch = indoc! {"
+ @@ -1,4 +1,3 @@
+ fn main() {
+ -
+ println!(\"hello\");
+ }
+ "};
+ let cursor = cursor_on_line(2);
+ assert!(has_isolated_whitespace_changes(patch, Some(cursor)));
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_real_world_rename() {
+ // Real-world patch that was reported as returning 0 tokens
+ let patch = "--- a/sip_call\\README.md\n+++ b/sip_call\\README.md\n@@ -1,1 +1,1 @@\n-# \n+# SIP Call\n";
+ let counts = count_patch_token_changes(patch);
+ // "# " vs "# SIP Call" — the "SIP" and "Call" tokens (and a whitespace token) are inserted
+ assert!(
+ counts.inserted_tokens > 0,
+ "expected inserted tokens > 0, got {}",
+ counts.inserted_tokens
+ );
+ assert_eq!(counts.deleted_tokens, 0);
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_real_world_expansion() {
+ // Real-world patch: single token expanded to multiple lines
+ let patch = "--- a/task1/src/app/app.html\n+++ b/task1/src/app/app.html\n@@ -1,7 +1,9 @@\n <style>\n- m\n+ main {\n+ \n+ }\n </style>\n \n <main>\n \n </main>\n";
+ let counts = count_patch_token_changes(patch);
+ assert!(
+ counts.inserted_tokens > 0,
+ "expected inserted tokens > 0, got {}",
+ counts.inserted_tokens
+ );
+ assert!(
+ counts.deleted_tokens > 0,
+ "expected deleted tokens > 0, got {}",
+ counts.deleted_tokens
+ );
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_simple_replacement() {
+ let patch = indoc! {"
+ @@ -1,3 +1,3 @@
+ fn main() {
+ - println!(\"hello\");
+ + println!(\"world\");
+ }
+ "};
+ let counts = count_patch_token_changes(patch);
+ assert_eq!(counts.deleted_tokens, 1, "deleted: \"hello\"");
+ assert_eq!(counts.inserted_tokens, 1, "inserted: \"world\"");
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_insertion_only() {
+ let patch = indoc! {"
+ @@ -1,2 +1,3 @@
+ fn main() {
+ + println!(\"hello\");
+ }
+ "};
+ let counts = count_patch_token_changes(patch);
+ assert_eq!(counts.deleted_tokens, 0);
+ assert!(counts.inserted_tokens > 0);
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_deletion_only() {
+ let patch = indoc! {"
+ @@ -1,3 +1,2 @@
+ fn main() {
+ - println!(\"hello\");
+ }
+ "};
+ let counts = count_patch_token_changes(patch);
+ assert!(counts.deleted_tokens > 0);
+ assert_eq!(counts.inserted_tokens, 0);
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_empty_patch() {
+ let patch = "";
+ let counts = count_patch_token_changes(patch);
+ assert_eq!(counts.deleted_tokens, 0);
+ assert_eq!(counts.inserted_tokens, 0);
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_multiple_hunks() {
+ let patch = indoc! {"
+ @@ -1,3 +1,3 @@
+ fn main() {
+ - let x = 1;
+ + let x = 2;
+ }
+ @@ -10,3 +10,3 @@
+ fn other() {
+ - let y = 3;
+ + let y = 4;
+ }
+ "};
+ let counts = count_patch_token_changes(patch);
+ assert_eq!(counts.deleted_tokens, 2, "deleted: \"1\" and \"3\"");
+ assert_eq!(counts.inserted_tokens, 2, "inserted: \"2\" and \"4\"");
+ }
+
+ #[test]
+ fn test_count_patch_token_changes_multiword_change() {
+ let patch = indoc! {"
+ @@ -1,1 +1,1 @@
+ -hello world foo
+ +hello bar baz
+ "};
+ let counts = count_patch_token_changes(patch);
+ // "world" and "foo" deleted, "bar" and "baz" inserted
+ // (whitespace tokens between them may also count)
+ assert!(counts.deleted_tokens >= 2);
+ assert!(counts.inserted_tokens >= 2);
+ }
+
+ #[test]
+ fn test_whitespace_collapse() {
+ let text = "abc \n\n\n 123";
+ let collapsed = collapse_whitespace(text.chars());
+ assert_eq!(
+ collapsed,
+ vec!['a', 'b', 'c', ' ', '\n', ' ', '1', '2', '3']
+ );
+ }
+}
@@ -0,0 +1,648 @@
+use std::ops::Range;
+use std::path::Path;
+use std::sync::Arc;
+
+use language::{char_diff, text_diff};
+use zeta_prompt::udiff::apply_diff_to_string;
+
+fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String {
+ let hunks = parse_diff_hunks(diff_str);
+ let mut result = text.to_string();
+
+ for hunk in hunks {
+ let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk));
+ if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) {
+ result = updated;
+ }
+ }
+
+ result
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct ParsedHunk {
+ old_start: u32,
+ old_count: u32,
+ new_start: u32,
+ new_count: u32,
+ lines: Vec<HunkLine>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum HunkLine {
+ Context(String),
+ Addition(String),
+ Deletion(String),
+}
+
+fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
+ let line = line.strip_prefix("@@ -")?;
+ let (old_part, rest) = line.split_once(' ')?;
+ let rest = rest.strip_prefix('+')?;
+ let (new_part, _) = rest.split_once(" @@")?;
+
+ let (old_start, old_count) = if let Some((start, count)) = old_part.split_once(',') {
+ (start.parse().ok()?, count.parse().ok()?)
+ } else {
+ (old_part.parse().ok()?, 1)
+ };
+
+ let (new_start, new_count) = if let Some((start, count)) = new_part.split_once(',') {
+ (start.parse().ok()?, count.parse().ok()?)
+ } else {
+ (new_part.parse().ok()?, 1)
+ };
+
+ Some((old_start, old_count, new_start, new_count))
+}
+
+fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
+ let mut hunks = Vec::new();
+ let mut current_hunk: Option<ParsedHunk> = None;
+
+ for line in diff.lines() {
+ if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
+ if let Some(hunk) = current_hunk.take() {
+ hunks.push(hunk);
+ }
+ current_hunk = Some(ParsedHunk {
+ old_start,
+ old_count,
+ new_start,
+ new_count,
+ lines: Vec::new(),
+ });
+ } else if let Some(ref mut hunk) = current_hunk {
+ if let Some(stripped) = line.strip_prefix('+') {
+ hunk.lines.push(HunkLine::Addition(stripped.to_string()));
+ } else if let Some(stripped) = line.strip_prefix('-') {
+ hunk.lines.push(HunkLine::Deletion(stripped.to_string()));
+ } else if let Some(stripped) = line.strip_prefix(' ') {
+ hunk.lines.push(HunkLine::Context(stripped.to_string()));
+ } else if line.is_empty() {
+ hunk.lines.push(HunkLine::Context(String::new()));
+ }
+ }
+ }
+
+ if let Some(hunk) = current_hunk {
+ hunks.push(hunk);
+ }
+
+ hunks
+}
+
+fn format_hunk(hunk: &ParsedHunk) -> String {
+ let mut result = format!(
+ "@@ -{},{} +{},{} @@\n",
+ hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
+ );
+ for line in &hunk.lines {
+ match line {
+ HunkLine::Context(text) => {
+ result.push(' ');
+ result.push_str(text);
+ result.push('\n');
+ }
+ HunkLine::Addition(text) => {
+ result.push('+');
+ result.push_str(text);
+ result.push('\n');
+ }
+ HunkLine::Deletion(text) => {
+ result.push('-');
+ result.push_str(text);
+ result.push('\n');
+ }
+ }
+ }
+ result
+}
+
+fn filter_diff_hunks_by_excerpt(
+ diff: &str,
+ excerpt_start_row: u32,
+ excerpt_row_count: u32,
+) -> (String, i32) {
+ let hunks = parse_diff_hunks(diff);
+ let excerpt_start_0based = excerpt_start_row;
+ let excerpt_end_0based = excerpt_start_row + excerpt_row_count;
+
+ let mut filtered_hunks = Vec::new();
+ let mut cumulative_line_offset: i32 = 0;
+
+ for hunk in hunks {
+ let hunk_start_0based = hunk.new_start.saturating_sub(1);
+ let hunk_end_0based = hunk_start_0based + hunk.new_count;
+
+ let additions: i32 = hunk
+ .lines
+ .iter()
+ .filter(|l| matches!(l, HunkLine::Addition(_)))
+ .count() as i32;
+ let deletions: i32 = hunk
+ .lines
+ .iter()
+ .filter(|l| matches!(l, HunkLine::Deletion(_)))
+ .count() as i32;
+ let hunk_line_delta = additions - deletions;
+
+ if hunk_end_0based <= excerpt_start_0based {
+ cumulative_line_offset += hunk_line_delta;
+ continue;
+ }
+
+ if hunk_start_0based >= excerpt_end_0based {
+ continue;
+ }
+
+ let mut filtered_lines = Vec::new();
+ let mut current_row_0based = hunk_start_0based;
+ let mut filtered_old_count = 0u32;
+ let mut filtered_new_count = 0u32;
+ let mut first_included_row: Option<u32> = None;
+
+ for line in &hunk.lines {
+ match line {
+ HunkLine::Context(text) => {
+ if current_row_0based >= excerpt_start_0based
+ && current_row_0based < excerpt_end_0based
+ {
+ if first_included_row.is_none() {
+ first_included_row = Some(current_row_0based);
+ }
+ filtered_lines.push(HunkLine::Context(text.clone()));
+ filtered_old_count += 1;
+ filtered_new_count += 1;
+ }
+ current_row_0based += 1;
+ }
+ HunkLine::Addition(text) => {
+ if current_row_0based >= excerpt_start_0based
+ && current_row_0based < excerpt_end_0based
+ {
+ if first_included_row.is_none() {
+ first_included_row = Some(current_row_0based);
+ }
+ filtered_lines.push(HunkLine::Addition(text.clone()));
+ filtered_new_count += 1;
+ }
+ current_row_0based += 1;
+ }
+ HunkLine::Deletion(text) => {
+ if current_row_0based >= excerpt_start_0based
+ && current_row_0based < excerpt_end_0based
+ {
+ if first_included_row.is_none() {
+ first_included_row = Some(current_row_0based);
+ }
+ filtered_lines.push(HunkLine::Deletion(text.clone()));
+ filtered_old_count += 1;
+ }
+ }
+ }
+ }
+
+ if !filtered_lines.is_empty() {
+ let first_row = first_included_row.unwrap_or(excerpt_start_0based);
+ let new_start_1based = (first_row - excerpt_start_0based) + 1;
+
+ filtered_hunks.push(ParsedHunk {
+ old_start: new_start_1based,
+ old_count: filtered_old_count,
+ new_start: new_start_1based,
+ new_count: filtered_new_count,
+ lines: filtered_lines,
+ });
+ }
+
+ cumulative_line_offset += hunk_line_delta;
+ }
+
+ let mut result = String::new();
+ for hunk in &filtered_hunks {
+ result.push_str(&format_hunk(hunk));
+ }
+
+ (result, cumulative_line_offset)
+}
+
+fn compute_excerpt_aware_reversal_overlap(
+ edit_history_diffs: &[&str],
+ excerpt_content: &str,
+ excerpt_start_row: u32,
+ predicted_content: &str,
+) -> ReversalOverlap {
+ let mut current_content = excerpt_content.to_string();
+ let mut current_excerpt_start_row = excerpt_start_row;
+
+ for diff in edit_history_diffs.iter().rev() {
+ if diff.is_empty() {
+ continue;
+ }
+
+ let current_row_count = current_content.lines().count() as u32;
+ let (filtered_diff, _line_offset) =
+ filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1));
+
+ if filtered_diff.is_empty() {
+ let hunks = parse_diff_hunks(diff);
+ for hunk in hunks {
+ let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
+ if hunk_end <= current_excerpt_start_row {
+ let additions: u32 = hunk
+ .lines
+ .iter()
+ .filter(|l| matches!(l, HunkLine::Addition(_)))
+ .count() as u32;
+ let deletions: u32 = hunk
+ .lines
+ .iter()
+ .filter(|l| matches!(l, HunkLine::Deletion(_)))
+ .count() as u32;
+ if additions >= deletions {
+ current_excerpt_start_row =
+ current_excerpt_start_row.saturating_sub(additions - deletions);
+ } else {
+ current_excerpt_start_row += deletions - additions;
+ }
+ }
+ }
+ continue;
+ }
+
+ let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff));
+        match apply_diff_to_string(&reversed, &current_content) {
+ Ok(updated) => {
+ current_content = updated;
+ }
+ Err(_) => {
+ continue;
+ }
+ }
+
+ let hunks = parse_diff_hunks(diff);
+ for hunk in hunks {
+ let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
+ if hunk_end <= current_excerpt_start_row {
+ let additions: u32 = hunk
+ .lines
+ .iter()
+ .filter(|l| matches!(l, HunkLine::Addition(_)))
+ .count() as u32;
+ let deletions: u32 = hunk
+ .lines
+ .iter()
+ .filter(|l| matches!(l, HunkLine::Deletion(_)))
+ .count() as u32;
+ if additions >= deletions {
+ current_excerpt_start_row =
+ current_excerpt_start_row.saturating_sub(additions - deletions);
+ } else {
+ current_excerpt_start_row += deletions - additions;
+ }
+ }
+ }
+ }
+
+    compute_reversal_overlap(&current_content, excerpt_content, predicted_content)
+}
+
+fn reverse_diff(diff: &str) -> String {
+    let mut result: String = diff
+        .lines()
+        .map(|line| {
+            if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
+                format!("@@ -{},{} +{},{} @@", new_start, new_count, old_start, old_count)
+            } else if line.starts_with("--- ") {
+                line.replacen("--- ", "+++ ", 1)
+            } else if line.starts_with("+++ ") {
+                line.replacen("+++ ", "--- ", 1)
+            } else if line.starts_with('+') && !line.starts_with("+++") {
+                format!("-{}", &line[1..])
+            } else if line.starts_with('-') && !line.starts_with("---") {
+                format!("+{}", &line[1..])
+            } else {
+                line.to_string()
+            }
+        })
+        .collect::<Vec<_>>()
+        .join("\n");
+    if diff.ends_with('\n') { result.push('\n'); }
+    result
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct GranularEdit {
+ range: Range<usize>,
+ old_text: String,
+ new_text: String,
+}
+
+fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
+ text_diff(old_text, new_text)
+ .into_iter()
+ .map(|(range, new_text)| GranularEdit {
+ old_text: old_text[range.clone()].to_string(),
+ range,
+ new_text: new_text.to_string(),
+ })
+ .collect()
+}
+
+#[derive(Debug, Clone)]
+struct HistoryAdditionRange {
+ range_in_current: Range<usize>,
+}
+
+#[derive(Debug, Clone)]
+struct HistoryDeletionRange {
+ deleted_text: String,
+ position_in_current: usize,
+}
+
+fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
+ let mut result = Vec::new();
+ let mut offset_delta: isize = 0;
+
+ for edit in history_edits {
+ if !edit.new_text.is_empty() {
+ let new_start = (edit.range.start as isize + offset_delta) as usize;
+ let new_end = new_start + edit.new_text.len();
+ result.push(HistoryAdditionRange {
+ range_in_current: new_start..new_end,
+ });
+ }
+
+ offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
+ }
+
+ result
+}
+
+fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
+ let mut result = Vec::new();
+ let mut offset_delta: isize = 0;
+
+ for edit in history_edits {
+ if !edit.old_text.is_empty() {
+ let position_in_current = (edit.range.start as isize + offset_delta) as usize;
+ result.push(HistoryDeletionRange {
+ deleted_text: edit.old_text.clone(),
+ position_in_current,
+ });
+ }
+
+ offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
+ }
+
+ result
+}
+
+#[derive(Debug, Clone, Default, PartialEq, Eq)]
+struct ReversalOverlap {
+ chars_reversing_user_edits: usize,
+ total_chars_in_prediction: usize,
+}
+
+impl ReversalOverlap {
+ fn ratio(&self) -> f32 {
+ if self.total_chars_in_prediction == 0 {
+ 0.0
+ } else {
+ self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32
+ }
+ }
+}
+
+/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
+/// or where `new_text` appears as a subsequence within `old_text` (reduction).
+///
+/// For extensions: when the user's text is preserved (in order) within the prediction,
+/// we only count the newly inserted characters, not the preserved ones.
+/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
+/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
+///
+/// For reductions: when the prediction's text is preserved (in order) within the original,
+/// we only count the deleted characters, not the preserved ones.
+/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
+fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
+ edits
+ .into_iter()
+ .flat_map(|edit| {
+ if edit.old_text.is_empty() || edit.new_text.is_empty() {
+ return vec![edit];
+ }
+
+ // Use character-wise diff to find exact byte ranges of changes
+ let char_edits = char_diff(&edit.old_text, &edit.new_text);
+
+ let all_deletions = !char_edits.is_empty()
+ && char_edits
+ .iter()
+ .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
+ let all_insertions = !char_edits.is_empty()
+ && char_edits
+ .iter()
+ .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
+ if all_deletions || all_insertions {
+ return char_edits
+ .into_iter()
+ .map(|(range, replacement)| GranularEdit {
+ range: edit.range.start + range.start..edit.range.start + range.end,
+ old_text: edit.old_text[range].to_string(),
+ new_text: replacement.to_string(),
+ })
+ .collect();
+ }
+
+ // Otherwise, keep the original edit (mixed changes)
+ vec![edit]
+ })
+ .collect()
+}
+
+fn compute_reversal_overlap(
+ original_content: &str,
+ current_content: &str,
+ predicted_content: &str,
+) -> ReversalOverlap {
+ let history_edits =
+ normalize_extension_edits(compute_granular_edits(original_content, current_content));
+ let prediction_edits =
+ normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
+
+ let history_addition_ranges = compute_history_addition_ranges(&history_edits);
+ let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
+
+ let reversed_additions =
+ compute_reversed_additions(&history_addition_ranges, &prediction_edits);
+ let restored_deletions =
+ compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
+
+ let total_chars_in_prediction: usize = prediction_edits
+ .iter()
+ .map(|e| e.new_text.chars().count() + e.old_text.chars().count())
+ .sum();
+
+ ReversalOverlap {
+ chars_reversing_user_edits: reversed_additions + restored_deletions,
+ total_chars_in_prediction,
+ }
+}
+
+fn compute_reversed_additions(
+ history_addition_ranges: &[HistoryAdditionRange],
+ prediction_edits: &[GranularEdit],
+) -> usize {
+ let mut reversed_chars = 0;
+
+ for pred_edit in prediction_edits {
+ for history_addition in history_addition_ranges {
+ let overlap_start = pred_edit
+ .range
+ .start
+ .max(history_addition.range_in_current.start);
+ let overlap_end = pred_edit
+ .range
+ .end
+ .min(history_addition.range_in_current.end);
+
+ if overlap_start < overlap_end {
+ let relative_start = overlap_start - pred_edit.range.start;
+ let relative_end = overlap_end - pred_edit.range.start;
+ let overlap_text = &pred_edit.old_text[relative_start..relative_end];
+ reversed_chars += overlap_text.chars().count();
+ }
+ }
+ }
+
+ reversed_chars
+}
+
+fn compute_restored_deletions(
+ history_deletion_ranges: &[HistoryDeletionRange],
+ prediction_edits: &[GranularEdit],
+) -> usize {
+ let mut restored = 0;
+
+ for pred_edit in prediction_edits {
+ if pred_edit.new_text.is_empty() {
+ continue;
+ }
+
+ for deletion in history_deletion_ranges {
+ if pred_edit.range.contains(&deletion.position_in_current)
+ || deletion.position_in_current == pred_edit.range.start
+ {
+ restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text);
+ }
+ }
+ }
+
+ restored
+}
+
+fn compute_lcs_length(a: &str, b: &str) -> usize {
+ let a_chars: Vec<char> = a.chars().collect();
+ let b_chars: Vec<char> = b.chars().collect();
+ let m = a_chars.len();
+ let n = b_chars.len();
+
+ if m == 0 || n == 0 {
+ return 0;
+ }
+
+ let mut prev = vec![0; n + 1];
+ let mut curr = vec![0; n + 1];
+
+ for i in 1..=m {
+ for j in 1..=n {
+ if a_chars[i - 1] == b_chars[j - 1] {
+ curr[j] = prev[j - 1] + 1;
+ } else {
+ curr[j] = prev[j].max(curr[j - 1]);
+ }
+ }
+ std::mem::swap(&mut prev, &mut curr);
+ curr.fill(0);
+ }
+
+ prev[n]
+}
+
+fn filter_edit_history_by_path<'a>(
+ edit_history: &'a [Arc<zeta_prompt::Event>],
+ cursor_path: &std::path::Path,
+) -> Vec<&'a zeta_prompt::Event> {
+ edit_history
+ .iter()
+ .filter(|event| match event.as_ref() {
+ zeta_prompt::Event::BufferChange { path, .. } => {
+ let event_path = path.as_ref();
+ if event_path == cursor_path {
+ return true;
+ }
+ let stripped = event_path
+ .components()
+ .skip(1)
+ .collect::<std::path::PathBuf>();
+ stripped == cursor_path
+ }
+ })
+ .map(|arc| arc.as_ref())
+ .collect()
+}
+
+fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
+ match event {
+ zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
+ }
+}
+
+fn is_predicted_event(event: &zeta_prompt::Event) -> bool {
+ match event {
+ zeta_prompt::Event::BufferChange { predicted, .. } => *predicted,
+ }
+}
+
+pub fn compute_prediction_reversal_ratio_from_history(
+ current_content: &str,
+ edit_history: &[Arc<zeta_prompt::Event>],
+ excerpt_start_row: Option<u32>,
+ predicted_content: &str,
+ cursor_path: &Path,
+) -> f32 {
+ let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
+
+ let most_recent = match relevant_events.last() {
+ Some(event) if !is_predicted_event(event) => *event,
+ _ => return 0.0,
+ };
+
+ let diff = extract_diff_from_event(most_recent);
+ if diff.is_empty() {
+ return 0.0;
+ }
+
+ if let Some(excerpt_start_row) = excerpt_start_row {
+ let diffs = vec![diff];
+ let overlap = compute_excerpt_aware_reversal_overlap(
+ &diffs,
+ current_content,
+ excerpt_start_row,
+ predicted_content,
+ );
+ return overlap.ratio();
+ }
+
+ let reversed = reverse_diff(diff);
+ let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
+ let original_content = match apply_diff_to_string(&with_headers, current_content) {
+ Ok(updated_content) => updated_content,
+ Err(_) => apply_diff_to_string_lenient(&reversed, current_content),
+ };
+
+ let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
+ overlap.ratio()
+}
@@ -0,0 +1,27 @@
+pub fn count_tree_sitter_errors<'a>(nodes: impl Iterator<Item = tree_sitter::Node<'a>>) -> usize {
+ let mut total_count: usize = 0;
+ for node in nodes {
+ let mut cursor = node.walk();
+ 'node: loop {
+ let current = cursor.node();
+ if current.is_error() || current.is_missing() {
+ total_count += 1;
+ }
+ if current.has_error() && cursor.goto_first_child() {
+ continue;
+ }
+ if cursor.goto_next_sibling() {
+ continue;
+ }
+ loop {
+ if !cursor.goto_parent() {
+ break 'node;
+ }
+ if cursor.goto_next_sibling() {
+                    continue 'node;
+ }
+ }
+ }
+ }
+ total_count
+}