From edac6a7b7a2d7979fe1d4a1444acf2585b3ceb38 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Tue, 14 Apr 2026 21:04:10 +0300 Subject: [PATCH] Move edit prediction metrics into shared crate (#53912) Release Notes: - N/A --------- Co-authored-by: Ben Kunkle --- Cargo.lock | 15 + Cargo.toml | 2 + crates/edit_prediction/Cargo.toml | 1 + crates/edit_prediction/src/metrics.rs | 16 +- .../src/metrics/tree_sitter.rs | 88 - crates/edit_prediction_cli/Cargo.toml | 1 + crates/edit_prediction_cli/src/metrics.rs | 1319 +---------- .../src/reversal_tracking.rs | 2013 +---------------- crates/edit_prediction_metrics/Cargo.toml | 23 + crates/edit_prediction_metrics/LICENSE-GPL | 1 + .../src/edit_prediction_metrics.rs | 24 + .../src}/kept_rate.rs | 7 +- .../src/patch_metrics.rs | 1451 ++++++++++++ .../edit_prediction_metrics/src/reversal.rs | 648 ++++++ .../src}/tokenize.rs | 0 .../src/tree_sitter.rs | 27 + 16 files changed, 2232 insertions(+), 3404 deletions(-) delete mode 100644 crates/edit_prediction/src/metrics/tree_sitter.rs create mode 100644 crates/edit_prediction_metrics/Cargo.toml create mode 120000 crates/edit_prediction_metrics/LICENSE-GPL create mode 100644 crates/edit_prediction_metrics/src/edit_prediction_metrics.rs rename crates/{edit_prediction/src/metrics => edit_prediction_metrics/src}/kept_rate.rs (99%) create mode 100644 crates/edit_prediction_metrics/src/patch_metrics.rs create mode 100644 crates/edit_prediction_metrics/src/reversal.rs rename crates/{edit_prediction/src/metrics => edit_prediction_metrics/src}/tokenize.rs (100%) create mode 100644 crates/edit_prediction_metrics/src/tree_sitter.rs diff --git a/Cargo.lock b/Cargo.lock index aaa953d50fe3962d3a502342890caa1bae42c25f..23bd6699e5f912dff7963fed337abc2a57c8f6d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5144,6 +5144,7 @@ dependencies = [ "ctor", "db", "edit_prediction_context", + "edit_prediction_metrics", "edit_prediction_types", "feature_flags", "fs", @@ -5206,6 +5207,7 @@ dependencies = [ 
"debug_adapter_extension", "dirs", "edit_prediction", + "edit_prediction_metrics", "extension", "flate2", "fs", @@ -5280,6 +5282,19 @@ dependencies = [ "zeta_prompt", ] +[[package]] +name = "edit_prediction_metrics" +version = "0.1.0" +dependencies = [ + "indoc", + "language", + "pretty_assertions", + "serde", + "similar", + "tree-sitter", + "zeta_prompt", +] + [[package]] name = "edit_prediction_types" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index b182bb6c27037fd02bbbf2013dcda4edfc2fcac2..7c3d79c223a6c0e0b9000de533237465356a6936 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ members = [ "crates/edit_prediction", "crates/edit_prediction_cli", "crates/edit_prediction_context", + "crates/edit_prediction_metrics", "crates/edit_prediction_types", "crates/edit_prediction_ui", "crates/editor", @@ -480,6 +481,7 @@ zed_actions = { path = "crates/zed_actions" } zed_credentials_provider = { path = "crates/zed_credentials_provider" } zed_env_vars = { path = "crates/zed_env_vars" } edit_prediction = { path = "crates/edit_prediction" } +edit_prediction_metrics = { path = "crates/edit_prediction_metrics" } zeta_prompt = { path = "crates/zeta_prompt" } zlog = { path = "crates/zlog" } zlog_settings = { path = "crates/zlog_settings" } diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index ae6def3686c727c18607b5cc6c135e4a0d16613d..9e4805938f21d8c599bcc3bc513df1dfc8e241d0 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -31,6 +31,7 @@ credentials_provider.workspace = true db.workspace = true edit_prediction_types.workspace = true edit_prediction_context.workspace = true +edit_prediction_metrics.workspace = true feature_flags.workspace = true fs.workspace = true futures.workspace = true diff --git a/crates/edit_prediction/src/metrics.rs b/crates/edit_prediction/src/metrics.rs index 20abd683a53fa34397a32a24abb0b49f553c0895..643c7c44d831c3bdcff591eac2e348a5cb8e9363 100644 --- 
a/crates/edit_prediction/src/metrics.rs +++ b/crates/edit_prediction/src/metrics.rs @@ -1,10 +1,8 @@ -mod kept_rate; -mod tokenize; -mod tree_sitter; +pub use edit_prediction_metrics::KeptRateResult; -pub use kept_rate::KeptRateResult; -#[cfg(test)] -pub use kept_rate::TokenAnnotation; -pub use kept_rate::compute_kept_rate; -pub(crate) use tokenize::tokenize; -pub use tree_sitter::count_tree_sitter_errors; +pub use edit_prediction_metrics::compute_kept_rate; +use language::SyntaxLayer; + +pub fn count_tree_sitter_errors<'a>(layers: impl Iterator>) -> usize { + edit_prediction_metrics::count_tree_sitter_errors(layers.map(|layer| layer.node())) +} diff --git a/crates/edit_prediction/src/metrics/tree_sitter.rs b/crates/edit_prediction/src/metrics/tree_sitter.rs deleted file mode 100644 index 1bb200289ca5007fd4711f0cb46c80ea1153bf28..0000000000000000000000000000000000000000 --- a/crates/edit_prediction/src/metrics/tree_sitter.rs +++ /dev/null @@ -1,88 +0,0 @@ -use language::SyntaxLayer; - -pub fn count_tree_sitter_errors<'a>(layers: impl Iterator>) -> usize { - let mut total_count: usize = 0; - for layer in layers { - let node = layer.node(); - let mut cursor = node.walk(); - 'layer: loop { - let current = cursor.node(); - if current.is_error() || current.is_missing() { - total_count += 1; - } - if current.has_error() && cursor.goto_first_child() { - continue; - } - if cursor.goto_next_sibling() { - continue; - } - loop { - if !cursor.goto_parent() { - break 'layer; - } - if cursor.goto_next_sibling() { - continue; - } - } - } - } - total_count -} - -#[cfg(test)] -mod tests { - use std::ops::Range; - - use super::count_tree_sitter_errors; - use gpui::{AppContext as _, TestAppContext}; - use language::{Buffer, BufferSnapshot, rust_lang}; - - fn error_count_in_range(edited_buffer_snapshot: &BufferSnapshot, range: Range) -> usize { - let layers = edited_buffer_snapshot.syntax_layers_for_range(range, true); - count_tree_sitter_errors(layers) - } - - fn rust_snapshot(text: 
&str, cx: &mut TestAppContext) -> BufferSnapshot { - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); - while buffer.read_with(cx, |buffer, _| buffer.is_parsing()) { - cx.run_until_parked(); - } - buffer.read_with(cx, |buffer, _| buffer.snapshot()) - } - - #[gpui::test] - async fn counts_no_errors_for_valid_rust(cx: &mut TestAppContext) { - let text = "fn helper(value: usize) -> usize {\n value + 1\n}\n"; - let snapshot = rust_snapshot(text, cx); - - assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 0); - } - - #[gpui::test] - async fn counts_errors_for_invalid_rust(cx: &mut TestAppContext) { - let text = "fn helper(value: usize) -> usize {\n let total = ;\n total\n}\n"; - let snapshot = rust_snapshot(text, cx); - - assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 1); - } - - #[gpui::test] - async fn counts_no_errors_for_subrange_of_valid_rust(cx: &mut TestAppContext) { - let text = "fn first() -> usize {\n let value = 1;\n value + 1\n}\n"; - let snapshot = rust_snapshot(text, cx); - let body_start = text.find("let value").unwrap(); - let body_end = body_start + "let value = 1;".len(); - - assert_eq!(error_count_in_range(&snapshot, body_start..body_end), 0); - } - - #[gpui::test] - async fn counts_errors_for_subrange_of_invalid_rust(cx: &mut TestAppContext) { - let text = "fn second() -> usize {\n let broken = ;\n broken\n}\n"; - let snapshot = rust_snapshot(text, cx); - let error_start = text.find("let broken = ;").unwrap(); - let error_end = error_start + "let broken = ;".len(); - - assert_eq!(error_count_in_range(&snapshot, error_start..error_end), 1); - } -} diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 8aa4ff63aca1d9b6f418924c4ccc232d368d5a69..97db020a552e69ff977b69e72febab7d08325e3e 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -61,6 +61,7 @@ terminal_view.workspace = true util.workspace = 
true watch.workspace = true edit_prediction = { workspace = true, features = ["cli-support"] } +edit_prediction_metrics.workspace = true telemetry_events.workspace = true wasmtime.workspace = true zeta_prompt.workspace = true diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index b28edbb7eb12929ee883eed29a9ef775e100281f..916d1498e6e1aea62ab4ff6e4ac90af627f8c5e0 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,1301 +1,24 @@ -use collections::HashMap; +#![allow(unused_imports)] + +use crate::example::ActualCursor; + +pub use edit_prediction_metrics::ClassificationMetrics; +pub use edit_prediction_metrics::Counts; +pub use edit_prediction_metrics::DeltaChrFMetrics; +pub use edit_prediction_metrics::KeptRateResult; +pub use edit_prediction_metrics::TokenChangeCounts; +pub use edit_prediction_metrics::braces_disbalance; +pub use edit_prediction_metrics::compute_kept_rate; +pub use edit_prediction_metrics::count_patch_token_changes; +pub use edit_prediction_metrics::delta_chr_f; +pub use edit_prediction_metrics::delta_chr_f_beta; +pub use edit_prediction_metrics::exact_lines_match; +pub use edit_prediction_metrics::extract_changed_lines_from_diff; +pub use edit_prediction_metrics::is_editable_region_correct; -use crate::{ - example::ActualCursor, - reorder_patch::{Patch, PatchLine}, - word_diff::{DiffOp, diff_tokens, tokenize}, -}; - -pub type Counts = HashMap; -type CountsDelta = HashMap; - -/// Context characters needed on each side of a change to capture all affected n-grams -const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1; - -#[derive(Default, Debug, Clone)] -pub struct ClassificationMetrics { - pub true_positives: usize, - pub false_positives: usize, - pub false_negatives: usize, -} - -impl ClassificationMetrics { - pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { - let mut true_positives = 0; - let mut false_positives = 0; 
- let mut false_negatives = 0; - - for (ngram, &expected_count) in expected { - let actual_count = *actual.get(ngram).unwrap_or(&0); - if actual_count > expected_count { - false_positives += actual_count - expected_count; - } else { - false_negatives += expected_count - actual_count; - } - true_positives += expected_count.min(actual_count); - } - - for (ngram, &actual_count) in actual { - if !expected.contains_key(ngram) { - false_positives += actual_count; - } - } - - ClassificationMetrics { - true_positives, - false_positives, - false_negatives, - } - } - - pub fn accumulate(&mut self, other: &ClassificationMetrics) { - self.true_positives += other.true_positives; - self.false_positives += other.false_positives; - self.false_negatives += other.false_negatives; - } - - pub fn precision(&self) -> f64 { - if self.true_positives + self.false_positives == 0 { - 0.0 - } else { - self.true_positives as f64 / (self.true_positives + self.false_positives) as f64 - } - } - - pub fn recall(&self) -> f64 { - if self.true_positives + self.false_negatives == 0 { - 0.0 - } else { - self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64 - } - } - - pub fn f1(&self) -> f64 { - let precision = self.precision(); - let recall = self.recall(); - if precision + recall == 0.0 { - 0.0 - } else { - 2.0 * precision * recall / (precision + recall) - } - } -} - -enum ChrfWhitespace { - /// Preserve whitespace as-is - #[allow(unused)] - Unchanged, - - /// Ignore all whitespace differences - #[allow(unused)] - Ignore, - - /// Collapse whitespace into single spaces - Collapse, -} - -const CHR_F_CHAR_ORDER: usize = 6; -const CHR_F_BETA: f64 = 0.5; -const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse; - -pub fn delta_chr_f_beta() -> f64 { - CHR_F_BETA -} - -#[derive(Default, Debug, Clone)] -pub struct DeltaChrFMetrics { - pub score: f64, - pub beta: f64, - pub counts: ClassificationMetrics, - pub precision: f64, - pub recall: f64, -} - -/// Computes 
delta-chrF metrics that compare two sets of edits. -/// -/// This metric works by: -/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual -/// 2. Comparing these deltas to measure how well actual edits match expected edits -/// -/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match -/// the expected edits. -pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics { - if original == expected && expected == actual { - return DeltaChrFMetrics { - score: 100.0, - beta: CHR_F_BETA, - precision: 1.0, - recall: 1.0, - ..DeltaChrFMetrics::default() - }; - } - - let orig_chars: Vec = filter_whitespace_chars(original); - let exp_chars: Vec = filter_whitespace_chars(expected); - let act_chars: Vec = filter_whitespace_chars(actual); - - // Find the changed regions between original→expected and original→actual - // We only need to compute n-grams on these regions (plus context for boundary n-grams) - let (orig_for_exp, exp_region) = extract_changed_regions(&orig_chars, &exp_chars); - let (orig_for_act, act_region) = extract_changed_regions(&orig_chars, &act_chars); - - let mut total_precision = 0.0; - let mut total_recall = 0.0; - let mut total_counts = ClassificationMetrics::default(); - - for order in 1..=CHR_F_CHAR_ORDER { - let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order); - let exp_ngrams = count_ngrams_from_chars(&exp_region, order); - let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp); - - let orig_ngrams_for_act = count_ngrams_from_chars(&orig_for_act, order); - let act_ngrams = count_ngrams_from_chars(&act_region, order); - let actual_delta = compute_ngram_delta(&act_ngrams, &orig_ngrams_for_act); - - if expected_delta.is_empty() && actual_delta.is_empty() { - total_precision += 1.0; - total_recall += 1.0; - continue; - } - - let expected_counts = ngram_delta_to_counts(&expected_delta); - let actual_counts = 
ngram_delta_to_counts(&actual_delta); - - let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); - total_precision += counts.precision(); - total_recall += counts.recall(); - total_counts.accumulate(&counts); - } - - let average_precision = total_precision / CHR_F_CHAR_ORDER as f64; - let average_recall = total_recall / CHR_F_CHAR_ORDER as f64; - let score = if average_precision + average_recall == 0.0 { - 0.0 - } else { - (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall - / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall) - * 100.0 - }; - - DeltaChrFMetrics { - score, - beta: CHR_F_BETA, - counts: total_counts, - precision: average_precision, - recall: average_recall, - } -} - -/// Reference implementation of delta-chrF metrics (original, non-optimized version). -/// Used for testing that the optimized version produces identical results. -#[cfg(test)] -fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics { - if original == expected && expected == actual { - return DeltaChrFMetrics { - score: 100.0, - beta: CHR_F_BETA, - precision: 1.0, - recall: 1.0, - ..DeltaChrFMetrics::default() - }; - } - - let original_ngrams = chr_f_ngram_counts(original); - let expected_ngrams = chr_f_ngram_counts(expected); - let actual_ngrams = chr_f_ngram_counts(actual); - - let mut total_precision = 0.0; - let mut total_recall = 0.0; - let mut total_counts = ClassificationMetrics::default(); - - for order in 0..CHR_F_CHAR_ORDER { - let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]); - let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]); - - if expected_delta.is_empty() && actual_delta.is_empty() { - total_precision += 1.0; - total_recall += 1.0; - continue; - } - - let expected_counts = ngram_delta_to_counts(&expected_delta); - let actual_counts = ngram_delta_to_counts(&actual_delta); - - let counts = 
ClassificationMetrics::from_counts(&expected_counts, &actual_counts); - total_precision += counts.precision(); - total_recall += counts.recall(); - total_counts.accumulate(&counts); - } - - let average_precision = total_precision / CHR_F_CHAR_ORDER as f64; - let average_recall = total_recall / CHR_F_CHAR_ORDER as f64; - let score = if average_precision + average_recall == 0.0 { - 0.0 - } else { - (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall - / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall) - * 100.0 - }; - - DeltaChrFMetrics { - score, - beta: CHR_F_BETA, - counts: total_counts, - precision: average_precision, - recall: average_recall, - } -} - -/// Filter whitespace from a string and return as Vec -fn filter_whitespace_chars(text: &str) -> Vec { - match CHR_F_WHITESPACE { - ChrfWhitespace::Unchanged => text.chars().collect(), - ChrfWhitespace::Ignore => text.chars().filter(|c| !c.is_whitespace()).collect(), - ChrfWhitespace::Collapse => collapse_whitespace(text.chars()), - } -} - -/// Collapse whitespace into single spaces. -/// Newlines and spaces are collapsed separately. -fn collapse_whitespace(chars: impl Iterator) -> Vec { - let mut result = Vec::new(); - let mut last_whitespace = None; - for c in chars { - if c.is_whitespace() && c != '\n' { - if last_whitespace != Some(' ') { - result.push(' '); - last_whitespace = Some(' '); - } - } else if c == '\n' { - if last_whitespace != Some('\n') { - result.push(c); - last_whitespace = Some('\n'); - } - } else { - result.push(c); - last_whitespace = None; - } - } - result -} - -/// Extract only the changed regions between two texts, with context for n-gram boundaries. -/// -/// Returns (original_affected_region, modified_affected_region) as Vec. -/// -/// The key insight: when computing n-gram delta between two nearly-identical texts, -/// n-grams from unchanged regions cancel out. We only need to process: -/// 1. The changed content itself -/// 2. 
CONTEXT_CHARS (n-1) characters before and after, to capture boundary-crossing n-grams -fn extract_changed_regions(original: &[char], modified: &[char]) -> (Vec, Vec) { - // Find longest common prefix - let prefix_len = original - .iter() - .zip(modified.iter()) - .take_while(|(a, b)| a == b) - .count(); - - // Find longest common suffix (that doesn't overlap with prefix) - let orig_remaining = original.len().saturating_sub(prefix_len); - let mod_remaining = modified.len().saturating_sub(prefix_len); - let max_suffix = orig_remaining.min(mod_remaining); - - let suffix_len = original - .iter() - .rev() - .zip(modified.iter().rev()) - .take(max_suffix) - .take_while(|(a, b)| a == b) - .count(); - - // Calculate the changed region boundaries - let orig_change_start = prefix_len; - let orig_change_end = original.len().saturating_sub(suffix_len); - let mod_change_start = prefix_len; - let mod_change_end = modified.len().saturating_sub(suffix_len); - - // If there's no actual change, return empty regions - if orig_change_start >= orig_change_end && mod_change_start >= mod_change_end { - return (Vec::new(), Vec::new()); - } - - // Expand to include context for n-gram boundaries - let orig_context_start = orig_change_start.saturating_sub(CONTEXT_CHARS); - let orig_context_end = (orig_change_end + CONTEXT_CHARS).min(original.len()); - let mod_context_start = mod_change_start.saturating_sub(CONTEXT_CHARS); - let mod_context_end = (mod_change_end + CONTEXT_CHARS).min(modified.len()); - - let orig_region: Vec = original[orig_context_start..orig_context_end].to_vec(); - let mod_region: Vec = modified[mod_context_start..mod_context_end].to_vec(); - - (orig_region, mod_region) -} - -/// Count n-grams directly from a char slice (avoids String allocation for the full text) -fn count_ngrams_from_chars(chars: &[char], n: usize) -> Counts { - let mut counts = Counts::default(); - - if chars.len() < n { - return counts; - } - - for window in chars.windows(n) { - let ngram: String = 
window.iter().collect(); - *counts.entry(ngram).or_insert(0) += 1; - } - - counts -} - -#[allow(dead_code)] -fn chr_f_ngram_counts(text: &str) -> Vec { - let text = match CHR_F_WHITESPACE { - ChrfWhitespace::Unchanged => text.to_string(), - ChrfWhitespace::Ignore => text - .chars() - .filter(|c| !c.is_whitespace()) - .collect::(), - ChrfWhitespace::Collapse => collapse_whitespace(text.chars()) - .into_iter() - .collect::(), - }; - - (1..=CHR_F_CHAR_ORDER) - .map(|order| count_ngrams(&text, order)) - .collect() -} - -fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta { - let mut delta = CountsDelta::default(); - - for (ngram, &before_count) in before { - let after_count = *after.get(ngram).unwrap_or(&0); - delta.insert(ngram.clone(), after_count as isize - before_count as isize); - } - - for (ngram, &after_count) in after { - if !before.contains_key(ngram) { - delta.insert(ngram.clone(), after_count as isize); - } - } - - delta -} - -/// Convert negative counts to special deletion tokens. -/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1}, -/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo" -/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive. 
-fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts { - let mut counts = Counts::default(); - - for (ngram, &delta) in delta { - if delta > 0 { - counts.insert(ngram.clone(), delta as usize); - } else if delta < 0 { - counts.insert(format!("¬{ngram}"), delta.unsigned_abs()); - } - } - - counts -} - -#[allow(dead_code)] -fn count_ngrams(text: &str, n: usize) -> Counts { - let chars: Vec = text.chars().collect(); - let mut counts = Counts::default(); - - for window in chars.windows(n) { - let ngram: String = window.iter().collect(); - *counts.entry(ngram).or_insert(0) += 1; - } - - counts -} - -pub fn braces_disbalance(text: &str) -> usize { - let mut disbalance = 0isize; - - let a = text.chars().filter(|&c| c == '{').count() as isize; - let b = text.chars().filter(|&c| c == '}').count() as isize; - disbalance += (a - b).abs(); - - let a = text.chars().filter(|&c| c == '(').count() as isize; - let b = text.chars().filter(|&c| c == ')').count() as isize; - disbalance += (a - b).abs(); - - let a = text.chars().filter(|&c| c == '[').count() as isize; - let b = text.chars().filter(|&c| c == ']').count() as isize; - disbalance += (a - b).abs(); - - disbalance as usize -} - -/// Extracts changed lines from a unified diff string. -/// Returns a bag (multiset) of lines that were added (+) or removed (-). -/// The +/- prefix is included in the line to distinguish additions from deletions. -pub fn extract_changed_lines_from_diff(diff: &str) -> Counts { - let mut counts = Counts::default(); - - for line in diff.lines() { - // Skip file headers (--- and +++) - if line.starts_with("---") || line.starts_with("+++") { - continue; - } - // Skip hunk headers (@@) - if line.starts_with("@@") { - continue; - } - // Skip diff header lines (diff --git, index, etc.) 
- if line.starts_with("diff ") || line.starts_with("index ") { - continue; - } - // Include added and removed lines (with their prefix) - if line.starts_with('+') || line.starts_with('-') { - *counts.entry(line.to_string()).or_insert(0) += 1; - } - } - - counts -} - -/// Computes exact lines match metrics between expected and actual patches. -/// Treats changed lines as a bag (multiset) - order is discarded but count matters. -/// Returns ClassificationMetrics with TP/FP/FN counts. -pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics { - let expected_lines = extract_changed_lines_from_diff(expected_patch); - let actual_lines = extract_changed_lines_from_diff(actual_patch); - ClassificationMetrics::from_counts(&expected_lines, &actual_lines) -} - -/// Returns whether the patch contains any isolated whitespace-only changes. -/// -/// A whitespace-only change is an added or deleted line whose content is empty or -/// contains only whitespace. It is "isolated" when it is not adjacent to any -/// substantive (non-whitespace) change within the same contiguous change group. 
pub fn has_isolated_whitespace_changes(patch_str: &str, cursor: Option<&ActualCursor>) -> bool { - let patch = Patch::parse_unified_diff(patch_str); - - let cursor_new_file_line = cursor.as_ref().map(|c| (c.row + 1) as usize); - - for hunk in &patch.hunks { - let lines = &hunk.lines; - let mut new_text_line = hunk.new_start as usize; - - for (i, line) in lines.iter().enumerate() { - let content = match line { - PatchLine::Addition(s) => { - let addition_line = new_text_line; - new_text_line += 1; - if s.trim().is_empty() && cursor_new_file_line == Some(addition_line) { - continue; - } - s.as_str() - } - PatchLine::Deletion(s) => s.as_str(), - PatchLine::Context(_) => { - new_text_line += 1; - continue; - } - _ => continue, - }; - - if !content.trim().is_empty() { - continue; - } - - if is_whitespace_change_isolated(lines, i) { - return true; - } - } - } - - false + edit_prediction_metrics::has_isolated_whitespace_changes( + patch_str, + cursor.map(|cursor| cursor.row), + ) } - -fn is_whitespace_change_isolated(lines: &[PatchLine], index: usize) -> bool { - // Look backward for a non-whitespace change before hitting a context line - for line in lines[..index].iter().rev() { - match line { - PatchLine::Addition(s) | PatchLine::Deletion(s) => { - if !s.trim().is_empty() { - return false; - } - } - _ => break, - } - } - - // Look forward for a non-whitespace change before hitting a context line - for line in &lines[index + 1..] { - match line { - PatchLine::Addition(s) | PatchLine::Deletion(s) => { - if !s.trim().is_empty() { - return false; - } - } - _ => break, - } - } - - true -} - -/// A simple proxy for whether the prediction respects editable region. -pub fn is_editable_region_correct(actual_patch: &str) -> bool { - // A typical sign of a wrong editable region: a bunch of lines deletion - // at the beginning or end of the patch. 
- let patch = Patch::parse_unified_diff(actual_patch); - if patch.hunks.is_empty() { - return true; - } - - let hunk = &patch.hunks[0]; - let mut deletions_at_start = 0; - - for line in hunk.lines.iter() { - match line { - PatchLine::Deletion(_) => deletions_at_start += 1, - _ => break, - } - } - - if deletions_at_start >= 3 { - return false; - } - - true -} - -#[derive(Debug, Default, Clone)] -pub struct TokenChangeCounts { - pub inserted_tokens: usize, - pub deleted_tokens: usize, -} - -/// Counts the number of inserted and deleted tokens in a unified diff patch. -/// -/// Tokens are words and whitespace sequences (as defined by `word_diff::tokenize`). -/// Within each hunk, the old (`-`) and new (`+`) lines are compared at the token level -/// using an LCS-based diff, so modified lines only count the actually changed tokens -/// rather than the entire line. -pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts { - let mut counts = TokenChangeCounts::default(); - let mut old_lines: Vec<&str> = Vec::new(); - let mut new_lines: Vec<&str> = Vec::new(); - - let flush = - |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, counts: &mut TokenChangeCounts| { - if old_lines.is_empty() && new_lines.is_empty() { - return; - } - - let old_text: String = old_lines - .iter() - .map(|line| if line.len() > 1 { &line[1..] } else { "" }) - .collect::>() - .join("\n"); - - let new_text: String = new_lines - .iter() - .map(|line| if line.len() > 1 { &line[1..] } else { "" }) - .collect::>() - .join("\n"); - - let old_tokens = tokenize(&old_text); - let new_tokens = tokenize(&new_text); - let ops = diff_tokens(&old_tokens, &new_tokens); - - for op in ops { - match op { - DiffOp::Equal(..) 
=> {} - DiffOp::Delete(start, end) => { - counts.deleted_tokens += end - start; - } - DiffOp::Insert(start, end) => { - counts.inserted_tokens += end - start; - } - DiffOp::Replace { - old_start, - old_end, - new_start, - new_end, - } => { - counts.deleted_tokens += old_end - old_start; - counts.inserted_tokens += new_end - new_start; - } - } - } - - old_lines.clear(); - new_lines.clear(); - }; - - for line in patch.lines() { - if line.starts_with("---") - || line.starts_with("+++") - || line.starts_with("@@") - || line.starts_with("diff ") - || line.starts_with("index ") - { - flush(&mut old_lines, &mut new_lines, &mut counts); - } else if line.starts_with('-') { - old_lines.push(line); - } else if line.starts_with('+') { - new_lines.push(line); - } else { - flush(&mut old_lines, &mut new_lines, &mut counts); - } - } - - flush(&mut old_lines, &mut new_lines, &mut counts); - counts -} - -#[cfg(test)] -mod test_optimization { - use super::*; - - #[test] - fn test_extract_changed_regions_simple() { - let original: Vec = "hello world".chars().collect(); - let modified: Vec = "hello there".chars().collect(); - - let (orig_region, mod_region) = extract_changed_regions(&original, &modified); - - // "world" vs "there" - with 5 chars context, we get "ello world" vs "ello there" - // (or less if not enough chars available) - assert!(orig_region.len() < original.len()); - assert!(mod_region.len() < modified.len()); - } - - #[test] - fn test_extract_changed_regions_insertion() { - let original: Vec = "abcdef".chars().collect(); - let modified: Vec = "abcXYZdef".chars().collect(); - - let (orig_region, mod_region) = extract_changed_regions(&original, &modified); - - // The insertion is between c and d, so we need context around that point - assert!(orig_region.len() <= original.len()); - assert!(mod_region.iter().collect::().contains("XYZ")); - } - - #[test] - fn test_extract_changed_regions_identical() { - let text: Vec = "identical text".chars().collect(); - - let 
(orig_region, mod_region) = extract_changed_regions(&text, &text); - - // When texts are identical, regions should be empty - assert!(orig_region.is_empty()); - assert!(mod_region.is_empty()); - } - - #[test] - fn test_optimized_matches_original_score() { - // Test that our optimized version produces the same results - let test_cases = vec![ - ("hello world", "hello there", "hello world"), - ( - "fn main() {}", - "fn main() { println!(); }", - "fn main() { print!(); }", - ), - ("abcdefghij", "abcXXXghij", "abcYYghij"), - ("unchanged", "unchanged", "unchanged"), - ( - "prefix middle suffix", - "prefix CHANGED suffix", - "prefix middle suffix", - ), - ]; - - for (original, expected, actual) in test_cases { - let score = delta_chr_f(original, expected, actual).score; - // Just verify it produces a reasonable score (0-100) - assert!( - score >= 0.0 && score <= 100.0, - "Score {} out of range for ({}, {}, {})", - score, - original, - expected, - actual - ); - } - } - - #[test] - fn test_optimized_equals_reference() { - // Comprehensive test that optimized version matches reference implementation exactly - let test_cases = vec![ - // Basic cases - ("hello world", "hello there", "hello world"), - ("hello world", "hello there", "hello there"), - ("unchanged", "unchanged", "unchanged"), - // Code-like cases - ( - "fn main() { println!(\"Hello\"); }", - "fn main() { println!(\"Hello, World!\"); }", - "fn main() { println!(\"Hello, World!\"); }", - ), - ( - "fn main() { println!(\"Hello\"); }", - "fn main() { println!(\"Hello, World!\"); }", - "fn main() { println!(\"Goodbye\"); }", - ), - // Insertion - ("abcdef", "abcXYZdef", "abcdef"), - ("abcdef", "abcXYZdef", "abcXYZdef"), - ("abcdef", "abcXYZdef", "abcABCdef"), - // Deletion - ("abcXYZdef", "abcdef", "abcXYZdef"), - ("abcXYZdef", "abcdef", "abcdef"), - // Multiple changes (simulated by different expected/actual) - ("one two three four", "one THREE four", "one two FOUR"), - // Edge cases - ("a", "b", "c"), - ("", "abc", 
""), - ("abc", "", "abc"), - // Longer text with small change - ( - "This is a longer piece of text that contains many words and characters to process", - "This is a longer piece of TEXT that contains many words and characters to process", - "This is a longer piece of text that contains many words and characters to process", - ), - // Change at the beginning - ( - "ORIGINAL start of text", - "NEW start of text", - "DIFFERENT start of text", - ), - // Change at the end - ( - "text ending ORIGINAL", - "text ending NEW", - "text ending DIFFERENT", - ), - // Whitespace (should be ignored) - ("hello world", "hello there", "hello world"), - ("a b c d", "a X c d", "a Y c d"), - ]; - - for (original, expected, actual) in test_cases { - let optimized_metrics = delta_chr_f(original, expected, actual); - let reference_metrics = delta_chr_f_reference(original, expected, actual); - - assert!( - (optimized_metrics.score - reference_metrics.score).abs() < 1e-10, - "Score mismatch for ({:?}, {:?}, {:?}):\n optimized: {}\n reference: {}", - original, - expected, - actual, - optimized_metrics.score, - reference_metrics.score - ); - assert_eq!( - optimized_metrics.counts.true_positives, - reference_metrics.counts.true_positives - ); - assert_eq!( - optimized_metrics.counts.false_positives, - reference_metrics.counts.false_positives - ); - assert_eq!( - optimized_metrics.counts.false_negatives, - reference_metrics.counts.false_negatives - ); - assert!((optimized_metrics.precision - reference_metrics.precision).abs() < 1e-10); - assert!((optimized_metrics.recall - reference_metrics.recall).abs() < 1e-10); - } - } - - #[test] - fn test_delta_chr_f_metrics_include_counts_and_rates() { - let original = "one two three"; - let expected = "one three"; - let actual = "one two four"; - - let metrics = delta_chr_f(original, expected, actual); - - assert!(metrics.score > 20.0 && metrics.score < 40.0); - assert!(metrics.counts.true_positives > 0); - assert!(metrics.counts.false_positives > 0); - 
assert!(metrics.counts.false_negatives > 0); - assert!(metrics.precision > 0.0 && metrics.precision < 1.0); - assert!(metrics.recall > 0.0 && metrics.recall < 1.0); - assert_eq!(metrics.beta, CHR_F_BETA); - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::example::ActualCursor; - use indoc::indoc; - - fn cursor_on_line(one_based_line: u32) -> ActualCursor { - ActualCursor { - path: String::new(), - row: one_based_line - 1, - column: 0, - offset: 0, - editable_region_offset: None, - } - } - - #[test] - fn test_delta_chr_f_perfect_match() { - let original = "fn main() { println!(\"Hello\");}"; - let expected = "fn main() { println!(\"Hello, World!\");}"; - - let score = delta_chr_f(original, expected, expected).score; - assert!((score - 100.0).abs() < 1e-2); - } - - #[test] - fn test_delta_chr_f_wrong_edit() { - // When the edit is wrong - let original = "one two three"; - let expected = "one three"; // deleted "two " - let actual = "one two four"; // deleted "three", added "four" - - // Then the score should be low - let score = delta_chr_f(original, expected, actual).score; - assert!(score > 20.0 && score < 40.0); - } - - #[test] - fn test_delta_chr_f_partial_match() { - let original = "let x = 42;"; - let expected = "let x = 100;"; - let actual = "let x = 99;"; - - // We got the edit location right, but the replacement text is wrong. - // Deleted ngrams will match, bringing the score somewhere in the middle. 
- let score = delta_chr_f(original, expected, actual).score; - assert!(score > 40.0 && score < 60.0); - } - - #[test] - fn test_delta_chr_f_missed_edit() { - // When predictions makes no changes - let original = "prefix old suffix"; - let expected = "prefix new suffix"; - let actual = "prefix old suffix"; // no change - - // Then the score should be low (all expected changes are false negatives) - let score = delta_chr_f(original, expected, actual).score; - assert!(score < 20.0); - } - - #[test] - fn test_delta_chr_f_extra_edit() { - // When adding unexpected content - let original = "helloworld"; - let expected = "helloworld"; // no change expected - let actual = "helloextraworld"; // added "extra" - - // Then the score should be low (all actual changes are false positives) - let score = delta_chr_f(original, expected, actual).score; - assert!(score < 20.0); - } - - #[test] - fn test_delta_chr_f_no_changes() { - let text = "unchanged text"; - let score = delta_chr_f(text, text, text).score; - assert!((score - 100.0).abs() < 1e-2); - } - - #[test] - fn test_braces_disbalance() { - let text = "let x = { 1 + 2 };"; - assert_eq!(braces_disbalance(text), 0); - - let text = "let x = { 1 + 2"; - assert_eq!(braces_disbalance(text), 1); - - let text = "let x = { 1 + 2 )"; - assert_eq!(braces_disbalance(text), 2); - } - - #[test] - fn test_extract_changed_lines_from_diff() { - let diff = r#"--- a/file.rs -+++ b/file.rs -@@ -1,3 +1,3 @@ - fn main() { -- println!("hello"); -+ println!("world"); - }"#; - - let counts = extract_changed_lines_from_diff(diff); - assert_eq!(counts.get("- println!(\"hello\");"), Some(&1)); - assert_eq!(counts.get("+ println!(\"world\");"), Some(&1)); - assert_eq!(counts.len(), 2); - } - - #[test] - fn test_extract_changed_lines_skips_headers() { - let diff = r#"diff --git a/file.rs b/file.rs -index abc123..def456 100644 ---- a/file.rs -+++ b/file.rs -@@ -1,2 +1,2 @@ --old line -+new line"#; - - let counts = extract_changed_lines_from_diff(diff); - 
assert_eq!(counts.get("-old line"), Some(&1)); - assert_eq!(counts.get("+new line"), Some(&1)); - assert_eq!(counts.len(), 2); - } - - #[test] - fn test_exact_lines_match_perfect() { - let expected = r#"--- a/file.rs -+++ b/file.rs -@@ -1,3 +1,3 @@ --old line 1 --old line 2 -+new line 1 -+new line 2"#; - - let actual = r#"--- a/file.rs -+++ b/file.rs -@@ -1,3 +1,3 @@ --old line 1 --old line 2 -+new line 1 -+new line 2"#; - - let metrics = exact_lines_match(expected, actual); - assert_eq!(metrics.true_positives, 4); - assert_eq!(metrics.false_positives, 0); - assert_eq!(metrics.false_negatives, 0); - assert!((metrics.precision() - 1.0).abs() < 1e-6); - assert!((metrics.recall() - 1.0).abs() < 1e-6); - assert!((metrics.f1() - 1.0).abs() < 1e-6); - } - - #[test] - fn test_exact_lines_match_partial() { - let expected = r#"-old line 1 --old line 2 -+new line 1 -+new line 2"#; - - let actual = r#"-old line 1 -+new line 1 -+extra line"#; - - let metrics = exact_lines_match(expected, actual); - // TP: "-old line 1" and "+new line 1" (2) - // FP: "+extra line" (1) - // FN: "-old line 2" and "+new line 2" (2) - assert_eq!(metrics.true_positives, 2); - assert_eq!(metrics.false_positives, 1); - assert_eq!(metrics.false_negatives, 2); - } - - #[test] - fn test_exact_lines_match_no_overlap() { - let expected = r#"-line a -+line b"#; - - let actual = r#"-line x -+line y"#; - - let metrics = exact_lines_match(expected, actual); - assert_eq!(metrics.true_positives, 0); - assert_eq!(metrics.false_positives, 2); - assert_eq!(metrics.false_negatives, 2); - assert!((metrics.precision()).abs() < 1e-6); - assert!((metrics.recall()).abs() < 1e-6); - } - - #[test] - fn test_exact_lines_match_duplicate_lines() { - let expected = r#"+line a -+line a -+line a"#; - - let actual = r#"+line a -+line a"#; - - let metrics = exact_lines_match(expected, actual); - // Expected has 3 "+line a", actual has 2 - // TP: 2, FN: 1, FP: 0 - assert_eq!(metrics.true_positives, 2); - 
assert_eq!(metrics.false_positives, 0); - assert_eq!(metrics.false_negatives, 1); - } - - #[test] - fn test_exact_lines_match_empty_patches() { - let metrics = exact_lines_match("", ""); - assert_eq!(metrics.true_positives, 0); - assert_eq!(metrics.false_positives, 0); - assert_eq!(metrics.false_negatives, 0); - } - - #[test] - fn test_is_editable_region_correct() { - let patch = indoc! {" - @@ -1,1 +1,1 @@ - -context - -removed - -from the beginning of the file - import sys - +sys.exit(0) - - "}; - assert!(!is_editable_region_correct(patch)); - - let patch = indoc! {" - @@ -1,1 +1,1 @@ - "}; - assert!(is_editable_region_correct(patch)); - } - - #[test] - fn test_isolated_whitespace_purely_whitespace_patch() { - let patch = indoc! {" - @@ -1,3 +1,4 @@ - fn main() { - + - println!(\"hello\"); - } - "}; - assert!(has_isolated_whitespace_changes(patch, None)); - } - - #[test] - fn test_isolated_whitespace_adjacent_to_real_change() { - let patch = indoc! {" - @@ -1,3 +1,4 @@ - fn main() { - + - + let x = 1; - println!(\"hello\"); - } - "}; - assert!(!has_isolated_whitespace_changes(patch, None)); - } - - #[test] - fn test_isolated_whitespace_no_whitespace_changes() { - let patch = indoc! {" - @@ -1,3 +1,3 @@ - fn main() { - - println!(\"hello\"); - + println!(\"world\"); - } - "}; - assert!(!has_isolated_whitespace_changes(patch, None)); - } - - #[test] - fn test_isolated_whitespace_deletion() { - let patch = indoc! {" - @@ -1,4 +1,3 @@ - fn main() { - - - println!(\"hello\"); - } - "}; - assert!(has_isolated_whitespace_changes(patch, None)); - } - - #[test] - fn test_isolated_whitespace_mixed_groups() { - let patch = indoc! 
{" - @@ -1,7 +1,8 @@ - fn main() { - + - let x = 1; - - let y = 2; - + let y = 3; - - + - println!(\"hello\"); - } - "}; - assert!(has_isolated_whitespace_changes(patch, None)); - } - - #[test] - fn test_isolated_whitespace_empty_patch() { - let patch = ""; - assert!(!has_isolated_whitespace_changes(patch, None)); - } - - #[test] - fn test_isolated_whitespace_skipped_on_cursor_line() { - // The addition of a blank line at new-file line 2 should be skipped - // because the cursor is on that line. - let patch = indoc! {" - @@ -1,3 +1,4 @@ - fn main() { - + - println!(\"hello\"); - } - "}; - // New-file line 2 is the added blank line - let cursor = cursor_on_line(2); - assert!(!has_isolated_whitespace_changes(patch, Some(&cursor))); - } - - #[test] - fn test_isolated_whitespace_not_skipped_when_cursor_on_different_line() { - // The blank line is at new-file line 2, but the cursor is on line 1. - let patch = indoc! {" - @@ -1,3 +1,4 @@ - fn main() { - + - println!(\"hello\"); - } - "}; - let cursor = cursor_on_line(1); - assert!(has_isolated_whitespace_changes(patch, Some(&cursor))); - } - - #[test] - fn test_isolated_whitespace_deletion_not_skipped_by_cursor() { - // Deletions don't have a new-file line, so cursor can't suppress them. - let patch = indoc! 
{" - @@ -1,4 +1,3 @@ - fn main() { - - - println!(\"hello\"); - } - "}; - let cursor = cursor_on_line(2); - assert!(has_isolated_whitespace_changes(patch, Some(&cursor))); - } - - #[test] - fn test_count_patch_token_changes_real_world_rename() { - // Real-world patch that was reported as returning 0 tokens - let patch = "--- a/sip_call\\README.md\n+++ b/sip_call\\README.md\n@@ -1,1 +1,1 @@\n-# \n+# SIP Call\n"; - let counts = count_patch_token_changes(patch); - // "# " vs "# SIP Call" — the "SIP" and "Call" tokens (and a whitespace token) are inserted - assert!( - counts.inserted_tokens > 0, - "expected inserted tokens > 0, got {}", - counts.inserted_tokens - ); - assert_eq!(counts.deleted_tokens, 0); - } - - #[test] - fn test_count_patch_token_changes_real_world_expansion() { - // Real-world patch: single token expanded to multiple lines - let patch = "--- a/task1/src/app/app.html\n+++ b/task1/src/app/app.html\n@@ -1,7 +1,9 @@\n \n \n
\n \n
\n"; - let counts = count_patch_token_changes(patch); - assert!( - counts.inserted_tokens > 0, - "expected inserted tokens > 0, got {}", - counts.inserted_tokens - ); - assert!( - counts.deleted_tokens > 0, - "expected deleted tokens > 0, got {}", - counts.deleted_tokens - ); - } - - #[test] - fn test_count_patch_token_changes_simple_replacement() { - let patch = indoc! {" - @@ -1,3 +1,3 @@ - fn main() { - - println!(\"hello\"); - + println!(\"world\"); - } - "}; - let counts = count_patch_token_changes(patch); - assert_eq!(counts.deleted_tokens, 1, "deleted: \"hello\""); - assert_eq!(counts.inserted_tokens, 1, "inserted: \"world\""); - } - - #[test] - fn test_count_patch_token_changes_insertion_only() { - let patch = indoc! {" - @@ -1,2 +1,3 @@ - fn main() { - + println!(\"hello\"); - } - "}; - let counts = count_patch_token_changes(patch); - assert_eq!(counts.deleted_tokens, 0); - assert!(counts.inserted_tokens > 0); - } - - #[test] - fn test_count_patch_token_changes_deletion_only() { - let patch = indoc! {" - @@ -1,3 +1,2 @@ - fn main() { - - println!(\"hello\"); - } - "}; - let counts = count_patch_token_changes(patch); - assert!(counts.deleted_tokens > 0); - assert_eq!(counts.inserted_tokens, 0); - } - - #[test] - fn test_count_patch_token_changes_empty_patch() { - let patch = ""; - let counts = count_patch_token_changes(patch); - assert_eq!(counts.deleted_tokens, 0); - assert_eq!(counts.inserted_tokens, 0); - } - - #[test] - fn test_count_patch_token_changes_multiple_hunks() { - let patch = indoc! {" - @@ -1,3 +1,3 @@ - fn main() { - - let x = 1; - + let x = 2; - } - @@ -10,3 +10,3 @@ - fn other() { - - let y = 3; - + let y = 4; - } - "}; - let counts = count_patch_token_changes(patch); - assert_eq!(counts.deleted_tokens, 2, "deleted: \"1\" and \"3\""); - assert_eq!(counts.inserted_tokens, 2, "inserted: \"2\" and \"4\""); - } - - #[test] - fn test_count_patch_token_changes_multiword_change() { - let patch = indoc! 
{" - @@ -1,1 +1,1 @@ - -hello world foo - +hello bar baz - "}; - let counts = count_patch_token_changes(patch); - // "world" and "foo" deleted, "bar" and "baz" inserted - // (whitespace tokens between them may also count) - assert!(counts.deleted_tokens >= 2); - assert!(counts.inserted_tokens >= 2); - } - - #[test] - fn test_whitespace_collapse() { - let text = "abc \n\n\n 123"; - let collapsed = collapse_whitespace(text.chars()); - assert_eq!( - collapsed, - vec!['a', 'b', 'c', ' ', '\n', ' ', '1', '2', '3'] - ); - } -} - -pub use edit_prediction::metrics::compute_kept_rate; diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index 34ddfd5f5ec0edca2b5de64a6f033a6463dcc133..58d52ed84e6eb8aaba621f4251e751ac89e17c02 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -1,2016 +1,17 @@ -use std::ops::Range; use std::path::Path; -use std::sync::Arc; - -use language::{char_diff, text_diff}; -use zeta_prompt::udiff::apply_diff_to_string; use zeta_prompt::ZetaPromptInput; -fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String { - let hunks = parse_diff_hunks(diff_str); - let mut result = text.to_string(); - - for hunk in hunks { - let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk)); - if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) { - result = updated; - } - } - - result -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct ParsedHunk { - old_start: u32, - old_count: u32, - new_start: u32, - new_count: u32, - lines: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -enum HunkLine { - Context(String), - Addition(String), - Deletion(String), -} - -fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> { - let line = line.strip_prefix("@@ -")?; - let (old_part, rest) = line.split_once(' ')?; - let rest = rest.strip_prefix('+')?; - let (new_part, _) = rest.split_once(" @@")?; 
- - let (old_start, old_count) = if let Some((start, count)) = old_part.split_once(',') { - (start.parse().ok()?, count.parse().ok()?) - } else { - (old_part.parse().ok()?, 1) - }; - - let (new_start, new_count) = if let Some((start, count)) = new_part.split_once(',') { - (start.parse().ok()?, count.parse().ok()?) - } else { - (new_part.parse().ok()?, 1) - }; - - Some((old_start, old_count, new_start, new_count)) -} - -fn parse_diff_hunks(diff: &str) -> Vec { - let mut hunks = Vec::new(); - let mut current_hunk: Option = None; - - for line in diff.lines() { - if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) { - if let Some(hunk) = current_hunk.take() { - hunks.push(hunk); - } - current_hunk = Some(ParsedHunk { - old_start, - old_count, - new_start, - new_count, - lines: Vec::new(), - }); - } else if let Some(ref mut hunk) = current_hunk { - if let Some(stripped) = line.strip_prefix('+') { - hunk.lines.push(HunkLine::Addition(stripped.to_string())); - } else if let Some(stripped) = line.strip_prefix('-') { - hunk.lines.push(HunkLine::Deletion(stripped.to_string())); - } else if let Some(stripped) = line.strip_prefix(' ') { - hunk.lines.push(HunkLine::Context(stripped.to_string())); - } else if line.is_empty() { - hunk.lines.push(HunkLine::Context(String::new())); - } - } - } - - if let Some(hunk) = current_hunk { - hunks.push(hunk); - } - - hunks -} - -fn format_hunk(hunk: &ParsedHunk) -> String { - let mut result = format!( - "@@ -{},{} +{},{} @@\n", - hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count - ); - for line in &hunk.lines { - match line { - HunkLine::Context(text) => { - result.push(' '); - result.push_str(text); - result.push('\n'); - } - HunkLine::Addition(text) => { - result.push('+'); - result.push_str(text); - result.push('\n'); - } - HunkLine::Deletion(text) => { - result.push('-'); - result.push_str(text); - result.push('\n'); - } - } - } - result -} - -fn filter_diff_hunks_by_excerpt( - diff: &str, - 
excerpt_start_row: u32, - excerpt_row_count: u32, -) -> (String, i32) { - let hunks = parse_diff_hunks(diff); - let excerpt_start_0based = excerpt_start_row; - let excerpt_end_0based = excerpt_start_row + excerpt_row_count; - - let mut filtered_hunks = Vec::new(); - let mut cumulative_line_offset: i32 = 0; - - for hunk in hunks { - let hunk_start_0based = hunk.new_start.saturating_sub(1); - let hunk_end_0based = hunk_start_0based + hunk.new_count; - - let additions: i32 = hunk - .lines - .iter() - .filter(|l| matches!(l, HunkLine::Addition(_))) - .count() as i32; - let deletions: i32 = hunk - .lines - .iter() - .filter(|l| matches!(l, HunkLine::Deletion(_))) - .count() as i32; - let hunk_line_delta = additions - deletions; - - if hunk_end_0based <= excerpt_start_0based { - cumulative_line_offset += hunk_line_delta; - continue; - } - - if hunk_start_0based >= excerpt_end_0based { - continue; - } - - let mut filtered_lines = Vec::new(); - let mut current_row_0based = hunk_start_0based; - let mut filtered_old_count = 0u32; - let mut filtered_new_count = 0u32; - let mut first_included_row: Option = None; - - for line in &hunk.lines { - match line { - HunkLine::Context(text) => { - if current_row_0based >= excerpt_start_0based - && current_row_0based < excerpt_end_0based - { - if first_included_row.is_none() { - first_included_row = Some(current_row_0based); - } - filtered_lines.push(HunkLine::Context(text.clone())); - filtered_old_count += 1; - filtered_new_count += 1; - } - current_row_0based += 1; - } - HunkLine::Addition(text) => { - if current_row_0based >= excerpt_start_0based - && current_row_0based < excerpt_end_0based - { - if first_included_row.is_none() { - first_included_row = Some(current_row_0based); - } - filtered_lines.push(HunkLine::Addition(text.clone())); - filtered_new_count += 1; - } - current_row_0based += 1; - } - HunkLine::Deletion(text) => { - if current_row_0based >= excerpt_start_0based - && current_row_0based < excerpt_end_0based - { - if 
first_included_row.is_none() { - first_included_row = Some(current_row_0based); - } - filtered_lines.push(HunkLine::Deletion(text.clone())); - filtered_old_count += 1; - } - } - } - } - - if !filtered_lines.is_empty() { - let first_row = first_included_row.unwrap_or(excerpt_start_0based); - let new_start_1based = (first_row - excerpt_start_0based) + 1; - - filtered_hunks.push(ParsedHunk { - old_start: new_start_1based, - old_count: filtered_old_count, - new_start: new_start_1based, - new_count: filtered_new_count, - lines: filtered_lines, - }); - } - - cumulative_line_offset += hunk_line_delta; - } - - let mut result = String::new(); - for hunk in &filtered_hunks { - result.push_str(&format_hunk(hunk)); - } - - (result, cumulative_line_offset) -} - -fn compute_excerpt_aware_reversal_overlap( - edit_history_diffs: &[&str], - excerpt_content: &str, - excerpt_start_row: u32, - predicted_content: &str, -) -> ReversalOverlap { - let mut current_content = excerpt_content.to_string(); - let mut current_excerpt_start_row = excerpt_start_row; - - for diff in edit_history_diffs.iter().rev() { - if diff.is_empty() { - continue; - } - - let current_row_count = current_content.lines().count() as u32; - let (filtered_diff, _line_offset) = - filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1)); - - if filtered_diff.is_empty() { - let hunks = parse_diff_hunks(diff); - for hunk in hunks { - let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count; - if hunk_end <= current_excerpt_start_row { - let additions: u32 = hunk - .lines - .iter() - .filter(|l| matches!(l, HunkLine::Addition(_))) - .count() as u32; - let deletions: u32 = hunk - .lines - .iter() - .filter(|l| matches!(l, HunkLine::Deletion(_))) - .count() as u32; - if additions >= deletions { - current_excerpt_start_row = - current_excerpt_start_row.saturating_sub(additions - deletions); - } else { - current_excerpt_start_row += deletions - additions; - } - } - } - continue; - } - - 
let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff)); - match apply_diff_to_string(&reversed, ¤t_content) { - Ok(updated) => { - current_content = updated; - } - Err(_) => { - continue; - } - } - - let hunks = parse_diff_hunks(diff); - for hunk in hunks { - let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count; - if hunk_end <= current_excerpt_start_row { - let additions: u32 = hunk - .lines - .iter() - .filter(|l| matches!(l, HunkLine::Addition(_))) - .count() as u32; - let deletions: u32 = hunk - .lines - .iter() - .filter(|l| matches!(l, HunkLine::Deletion(_))) - .count() as u32; - if additions >= deletions { - current_excerpt_start_row = - current_excerpt_start_row.saturating_sub(additions - deletions); - } else { - current_excerpt_start_row += deletions - additions; - } - } - } - } - - compute_reversal_overlap(¤t_content, excerpt_content, predicted_content) -} - -fn reverse_diff(diff: &str) -> String { - let mut result: String = diff - .lines() - .map(|line| { - if line.starts_with("--- ") { - line.replacen("--- ", "+++ ", 1) - } else if line.starts_with("+++ ") { - line.replacen("+++ ", "--- ", 1) - } else if line.starts_with('+') && !line.starts_with("+++") { - format!("-{}", &line[1..]) - } else if line.starts_with('-') && !line.starts_with("---") { - format!("+{}", &line[1..]) - } else { - line.to_string() - } - }) - .collect::>() - .join("\n"); - if diff.ends_with('\n') { - result.push('\n'); - } - result -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct GranularEdit { - range: Range, - old_text: String, - new_text: String, -} - -fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec { - text_diff(old_text, new_text) - .into_iter() - .map(|(range, new_text)| GranularEdit { - old_text: old_text[range.clone()].to_string(), - range, - new_text: new_text.to_string(), - }) - .collect() -} - -#[derive(Debug, Clone)] -struct HistoryAdditionRange { - range_in_current: Range, -} - -#[derive(Debug, Clone)] 
-struct HistoryDeletionRange { - deleted_text: String, - position_in_current: usize, -} - -fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec { - let mut result = Vec::new(); - let mut offset_delta: isize = 0; - - for edit in history_edits { - if !edit.new_text.is_empty() { - let new_start = (edit.range.start as isize + offset_delta) as usize; - let new_end = new_start + edit.new_text.len(); - result.push(HistoryAdditionRange { - range_in_current: new_start..new_end, - }); - } - - offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize; - } - - result -} - -fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec { - let mut result = Vec::new(); - let mut offset_delta: isize = 0; - - for edit in history_edits { - if !edit.old_text.is_empty() { - let position_in_current = (edit.range.start as isize + offset_delta) as usize; - result.push(HistoryDeletionRange { - deleted_text: edit.old_text.clone(), - position_in_current, - }); - } - - offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize; - } - - result -} - -#[derive(Debug, Clone, Default, PartialEq, Eq)] -struct ReversalOverlap { - chars_reversing_user_edits: usize, - total_chars_in_prediction: usize, -} - -impl ReversalOverlap { - fn ratio(&self) -> f32 { - if self.total_chars_in_prediction == 0 { - 0.0 - } else { - self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32 - } - } -} - -/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension), -/// or where `new_text` appears as a subsequence within `old_text` (reduction). -/// -/// For extensions: when the user's text is preserved (in order) within the prediction, -/// we only count the newly inserted characters, not the preserved ones. 
-/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()") -/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars -/// -/// For reductions: when the prediction's text is preserved (in order) within the original, -/// we only count the deleted characters, not the preserved ones. -/// E.g., "ifrom" → "from" becomes 1 deleted char ("i") -fn normalize_extension_edits(edits: Vec) -> Vec { - edits - .into_iter() - .flat_map(|edit| { - if edit.old_text.is_empty() || edit.new_text.is_empty() { - return vec![edit]; - } - - // Use character-wise diff to find exact byte ranges of changes - let char_edits = char_diff(&edit.old_text, &edit.new_text); - - let all_deletions = !char_edits.is_empty() - && char_edits - .iter() - .all(|(range, replacement)| !range.is_empty() && replacement.is_empty()); - let all_insertions = !char_edits.is_empty() - && char_edits - .iter() - .all(|(range, replacement)| range.is_empty() && !replacement.is_empty()); - if all_deletions || all_insertions { - return char_edits - .into_iter() - .map(|(range, replacement)| GranularEdit { - range: edit.range.start + range.start..edit.range.start + range.end, - old_text: edit.old_text[range].to_string(), - new_text: replacement.to_string(), - }) - .collect(); - } - - // Otherwise, keep the original edit (mixed changes) - vec![edit] - }) - .collect() -} - -fn compute_reversal_overlap( - original_content: &str, - current_content: &str, - predicted_content: &str, -) -> ReversalOverlap { - let history_edits = - normalize_extension_edits(compute_granular_edits(original_content, current_content)); - let prediction_edits = - normalize_extension_edits(compute_granular_edits(current_content, predicted_content)); - - let history_addition_ranges = compute_history_addition_ranges(&history_edits); - let history_deletion_ranges = compute_history_deletion_ranges(&history_edits); - - let reversed_additions = - compute_reversed_additions(&history_addition_ranges, 
&prediction_edits); - let restored_deletions = - compute_restored_deletions(&history_deletion_ranges, &prediction_edits); - - let total_chars_in_prediction: usize = prediction_edits - .iter() - .map(|e| e.new_text.chars().count() + e.old_text.chars().count()) - .sum(); - - ReversalOverlap { - chars_reversing_user_edits: reversed_additions + restored_deletions, - total_chars_in_prediction, - } -} - -fn compute_reversed_additions( - history_addition_ranges: &[HistoryAdditionRange], - prediction_edits: &[GranularEdit], -) -> usize { - let mut reversed_chars = 0; - - for pred_edit in prediction_edits { - for history_addition in history_addition_ranges { - let overlap_start = pred_edit - .range - .start - .max(history_addition.range_in_current.start); - let overlap_end = pred_edit - .range - .end - .min(history_addition.range_in_current.end); - - if overlap_start < overlap_end { - let relative_start = overlap_start - pred_edit.range.start; - let relative_end = overlap_end - pred_edit.range.start; - let overlap_text = &pred_edit.old_text[relative_start..relative_end]; - reversed_chars += overlap_text.chars().count(); - } - } - } - - reversed_chars -} - -fn compute_restored_deletions( - history_deletion_ranges: &[HistoryDeletionRange], - prediction_edits: &[GranularEdit], -) -> usize { - let mut restored = 0; - - for pred_edit in prediction_edits { - if pred_edit.new_text.is_empty() { - continue; - } - - for deletion in history_deletion_ranges { - if pred_edit.range.contains(&deletion.position_in_current) - || deletion.position_in_current == pred_edit.range.start - { - restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text); - } - } - } - - restored -} - -fn compute_lcs_length(a: &str, b: &str) -> usize { - let a_chars: Vec = a.chars().collect(); - let b_chars: Vec = b.chars().collect(); - let m = a_chars.len(); - let n = b_chars.len(); - - if m == 0 || n == 0 { - return 0; - } - - let mut prev = vec![0; n + 1]; - let mut curr = vec![0; n + 1]; - - for 
i in 1..=m { - for j in 1..=n { - if a_chars[i - 1] == b_chars[j - 1] { - curr[j] = prev[j - 1] + 1; - } else { - curr[j] = prev[j].max(curr[j - 1]); - } - } - std::mem::swap(&mut prev, &mut curr); - curr.fill(0); - } - - prev[n] -} - -fn filter_edit_history_by_path<'a>( - edit_history: &'a [Arc], - cursor_path: &std::path::Path, -) -> Vec<&'a zeta_prompt::Event> { - edit_history - .iter() - .filter(|event| match event.as_ref() { - zeta_prompt::Event::BufferChange { path, .. } => { - let event_path = path.as_ref(); - if event_path == cursor_path { - return true; - } - let stripped = event_path - .components() - .skip(1) - .collect::(); - stripped == cursor_path - } - }) - .map(|arc| arc.as_ref()) - .collect() -} - -fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str { - match event { - zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(), - } -} - -fn is_predicted_event(event: &zeta_prompt::Event) -> bool { - match event { - zeta_prompt::Event::BufferChange { predicted, .. 
} => *predicted, - } -} - pub fn compute_prediction_reversal_ratio( prompt_inputs: &ZetaPromptInput, predicted_content: &str, cursor_path: &Path, ) -> f32 { - let current_content: &str = prompt_inputs.cursor_excerpt.as_ref(); - - let edit_history: &[Arc] = &prompt_inputs.events; - let relevant_events = filter_edit_history_by_path(edit_history, cursor_path); - - let most_recent = match relevant_events.last() { - Some(event) if !is_predicted_event(event) => *event, - _ => return 0.0, - }; - - let diff = extract_diff_from_event(most_recent); - if diff.is_empty() { - return 0.0; - } - - if let Some(excerpt_start_row) = prompt_inputs.excerpt_start_row { - let diffs = vec![diff]; - let overlap = compute_excerpt_aware_reversal_overlap( - &diffs, - current_content, - excerpt_start_row, - predicted_content, - ); - return overlap.ratio(); - } - - let reversed = reverse_diff(diff); - let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed); - let original_content = match apply_diff_to_string(&with_headers, current_content) { - Ok(updated_content) => updated_content, - Err(_) => apply_diff_to_string_lenient(&reversed, current_content), - }; - - let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content); - overlap.ratio() -} - -#[cfg(test)] -mod tests { - use super::*; - use indoc::indoc; - use zeta_prompt::ExcerptRanges; - use zeta_prompt::udiff::apply_diff_to_string; - - fn make_test_prompt_inputs( - content: &str, - events: Vec>, - excerpt_start_row: Option, - ) -> ZetaPromptInput { - ZetaPromptInput { - cursor_path: Arc::from(Path::new("src/test.rs")), - cursor_excerpt: content.into(), - cursor_offset_in_excerpt: 0, - excerpt_start_row, - events, - related_files: Some(Vec::new()), - active_buffer_diagnostics: Vec::new(), - excerpt_ranges: ExcerptRanges { - editable_150: 0..content.len(), - editable_180: 0..content.len(), - editable_350: 0..content.len(), - editable_150_context_350: 0..content.len(), - editable_180_context_350: 
0..content.len(), - editable_350_context_150: 0..content.len(), - ..Default::default() - }, - syntax_ranges: None, - experiment: None, - in_open_source_repo: false, - can_collect_data: false, - repo_url: None, - } - } - - #[test] - fn test_reversal_overlap() { - struct Case { - name: &'static str, - original: &'static str, - current: &'static str, - predicted: &'static str, - expected_reversal_chars: usize, - expected_total_chars: usize, - } - - let cases = [ - Case { - name: "user_adds_line_prediction_removes_it", - original: indoc! {" - a - b - c"}, - current: indoc! {" - a - new line - b - c"}, - predicted: indoc! {" - a - b - c"}, - expected_reversal_chars: 9, - expected_total_chars: 9, - }, - Case { - name: "user_deletes_line_prediction_restores_it", - original: indoc! {" - a - deleted - b"}, - current: indoc! {" - a - b"}, - predicted: indoc! {" - a - deleted - b"}, - expected_reversal_chars: 8, - expected_total_chars: 8, - }, - Case { - name: "user_deletes_text_prediction_restores_partial", - original: "hello beautiful world", - current: "hello world", - predicted: "hello beautiful world", - expected_reversal_chars: 10, - expected_total_chars: 10, - }, - Case { - name: "user_deletes_foo_prediction_adds_bar", - original: "foo", - current: "", - predicted: "bar", - expected_reversal_chars: 0, - expected_total_chars: 3, - }, - Case { - name: "independent_edits_different_locations", - original: indoc! {" - line1 - line2 - line3"}, - current: indoc! {" - LINE1 - line2 - line3"}, - predicted: indoc! {" - LINE1 - line2 - LINE3"}, - expected_reversal_chars: 0, - expected_total_chars: 10, - }, - Case { - name: "no_history_edits", - original: "same", - current: "same", - predicted: "different", - expected_reversal_chars: 0, - expected_total_chars: 13, - }, - Case { - name: "user_replaces_text_prediction_reverses", - original: indoc! {" - keep - delete_me - keep2"}, - current: indoc! {" - keep - added - keep2"}, - predicted: indoc! 
{" - keep - delete_me - keep2"}, - expected_reversal_chars: 14, - expected_total_chars: 14, - }, - Case { - name: "user_modifies_word_prediction_modifies_differently", - original: "the quick brown fox", - current: "the slow brown fox", - predicted: "the fast brown fox", - expected_reversal_chars: 4, - expected_total_chars: 8, - }, - Case { - name: "user finishes function name (suffix)", - original: "", - current: "epr", - predicted: "eprintln!()", - expected_reversal_chars: 0, - expected_total_chars: 8, - }, - Case { - name: "user starts function name (prefix)", - original: "", - current: "my_function()", - predicted: "test_my_function()", - expected_reversal_chars: 0, - expected_total_chars: 5, - }, - Case { - name: "user types partial, prediction extends in multiple places", - original: "", - current: "test_my_function", - predicted: "a_test_for_my_special_function_plz", - expected_reversal_chars: 0, - expected_total_chars: 18, - }, - // Edge cases for subsequence matching - Case { - name: "subsequence with interleaved underscores", - original: "", - current: "a_b_c", - predicted: "_a__b__c__", - expected_reversal_chars: 0, - expected_total_chars: 5, - }, - Case { - name: "not a subsequence - different characters", - original: "", - current: "abc", - predicted: "xyz", - expected_reversal_chars: 3, - expected_total_chars: 6, - }, - Case { - name: "not a subsequence - wrong order", - original: "", - current: "abc", - predicted: "cba", - expected_reversal_chars: 3, - expected_total_chars: 6, - }, - Case { - name: "partial subsequence - only some chars match", - original: "", - current: "abcd", - predicted: "axbx", - expected_reversal_chars: 4, - expected_total_chars: 8, - }, - // Common completion patterns - Case { - name: "completing a method call", - original: "", - current: "vec.pu", - predicted: "vec.push(item)", - expected_reversal_chars: 0, - expected_total_chars: 8, - }, - Case { - name: "completing an import statement", - original: "", - current: "use 
std::col", - predicted: "use std::collections::HashMap", - expected_reversal_chars: 0, - expected_total_chars: 17, - }, - Case { - name: "completing a struct field", - original: "", - current: "name: St", - predicted: "name: String", - expected_reversal_chars: 0, - expected_total_chars: 4, - }, - Case { - name: "prediction replaces with completely different text", - original: "", - current: "hello", - predicted: "world", - expected_reversal_chars: 5, - expected_total_chars: 10, - }, - Case { - name: "empty prediction removes user text", - original: "", - current: "mistake", - predicted: "", - expected_reversal_chars: 7, - expected_total_chars: 7, - }, - Case { - name: "fixing typo is not reversal", - original: "", - current: "", - expected_reversal_chars: 0, - expected_total_chars: 2, - }, - Case { - name: "infix insertion not reversal", - original: indoc! {" - from my_project import Foo - "}, - current: indoc! {" - ifrom my_project import Foo - "}, - predicted: indoc! {" - import - from my_project import Foo - "}, - expected_reversal_chars: 0, - expected_total_chars: 6, - }, - Case { - name: "non-word based reversal", - original: "from", - current: "ifrom", - predicted: "from", - expected_reversal_chars: 1, - expected_total_chars: 1, - }, - Case { - name: "multiple insertions no reversal", - original: "print(\"Hello, World!\")", - current: "sys.(\"Hello, World!\")", - predicted: "sys.stdout.write(\"Hello, World!\\n\")", - expected_reversal_chars: 0, - expected_total_chars: 14, - }, - ]; - - for case in &cases { - let overlap = compute_reversal_overlap(case.original, case.current, case.predicted); - assert_eq!( - overlap.chars_reversing_user_edits, case.expected_reversal_chars, - "Test '{}': expected {} reversal chars, got {}", - case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits - ); - assert_eq!( - overlap.total_chars_in_prediction, case.expected_total_chars, - "Test '{}': expected {} total chars, got {}", - case.name, 
case.expected_total_chars, overlap.total_chars_in_prediction - ); - } - } - - #[test] - fn test_reverse_diff() { - let forward_diff = indoc! {" - --- a/file.rs - +++ b/file.rs - @@ -1,3 +1,4 @@ - fn main() { - + let x = 42; - println!(\"hello\"); - }"}; - - let reversed = reverse_diff(forward_diff); - - assert!( - reversed.contains("+++ a/file.rs"), - "Should have +++ for old path" - ); - assert!( - reversed.contains("--- b/file.rs"), - "Should have --- for new path" - ); - assert!( - reversed.contains("- let x = 42;"), - "Added line should become deletion" - ); - assert!( - reversed.contains(" fn main()"), - "Context lines should be unchanged" - ); - } - - #[test] - fn test_reverse_diff_roundtrip() { - // Applying a diff and then its reverse should get back to original - let original = indoc! {" - first line - hello world - last line - "}; - let modified = indoc! {" - first line - hello beautiful world - last line - "}; - - // unified_diff doesn't include file headers, but apply_diff_to_string needs them - let diff_body = language::unified_diff(original, modified); - let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body); - let reversed_diff = reverse_diff(&forward_diff); - - // Apply forward diff to original - let after_forward = apply_diff_to_string(&forward_diff, original).unwrap(); - assert_eq!(after_forward, modified); - - // Apply reversed diff to modified - let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap(); - assert_eq!(after_reverse, original); - } - - #[test] - fn test_filter_edit_history_by_path() { - // Test that filter_edit_history_by_path correctly matches paths when - // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs") - // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs") - let events = vec![ - Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("myrepo/src/file.rs")), - old_path: Arc::from(Path::new("myrepo/src/file.rs")), - diff: indoc! 
{" - @@ -1 +1 @@ - -old - +new"} - .into(), - predicted: false, - in_open_source_repo: true, - }), - Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("myrepo/other.rs")), - old_path: Arc::from(Path::new("myrepo/other.rs")), - diff: indoc! {" - @@ -1 +1 @@ - -a - +b"} - .into(), - predicted: false, - in_open_source_repo: true, - }), - Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("src/file.rs")), - old_path: Arc::from(Path::new("src/file.rs")), - diff: indoc! {" - @@ -1 +1 @@ - -x - +y"} - .into(), - predicted: false, - in_open_source_repo: true, - }), - ]; - - // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path - // "src/file.rs" exact match - let cursor_path = Path::new("src/file.rs"); - let filtered = filter_edit_history_by_path(&events, cursor_path); - assert_eq!( - filtered.len(), - 2, - "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)" - ); - - // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs" - // "src/file.rs" stripped -> "file.rs" == "file.rs" - let cursor_path = Path::new("file.rs"); - let filtered = filter_edit_history_by_path(&events, cursor_path); - assert_eq!( - filtered.len(), - 1, - "Should only match src/file.rs (stripped to file.rs)" - ); - - // "myrepo/other.rs" stripped -> "other.rs" == "other.rs" - let cursor_path = Path::new("other.rs"); - let filtered = filter_edit_history_by_path(&events, cursor_path); - assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs"); - } - - #[test] - fn test_reverse_diff_preserves_trailing_newline() { - let diff_with_trailing_newline = indoc! {" - --- a/file - +++ b/file - @@ -1 +1 @@ - -old - +new - "}; - let reversed = reverse_diff(diff_with_trailing_newline); - assert!( - reversed.ends_with('\n'), - "Reversed diff should preserve trailing newline" - ); - - let diff_without_trailing_newline = indoc! 
{" - --- a/file - +++ b/file - @@ -1 +1 @@ - -old - +new"}; - let reversed = reverse_diff(diff_without_trailing_newline); - assert!( - !reversed.ends_with('\n'), - "Reversed diff should not add trailing newline if original didn't have one" - ); - } - - #[test] - fn test_filter_hunks_by_excerpt_region() { - struct Case { - name: &'static str, - diff: &'static str, - excerpt_start_row: u32, - excerpt_row_count: u32, - expected_filtered_diff: &'static str, - expected_line_offset: i32, - } - - let cases = [ - Case { - name: "hunk_entirely_before_excerpt", - diff: indoc! {" - @@ -1,3 +1,4 @@ - line1 - +inserted - line2 - line3 - "}, - excerpt_start_row: 10, - excerpt_row_count: 5, - expected_filtered_diff: "", - expected_line_offset: 1, - }, - Case { - name: "hunk_entirely_inside_excerpt", - diff: indoc! {" - @@ -12,3 +12,4 @@ - line12 - +inserted - line13 - line14 - "}, - excerpt_start_row: 10, - excerpt_row_count: 10, - expected_filtered_diff: indoc! {" - @@ -2,3 +2,4 @@ - line12 - +inserted - line13 - line14 - "}, - expected_line_offset: 1, - }, - Case { - name: "hunk_entirely_after_excerpt", - diff: indoc! {" - @@ -50,3 +50,4 @@ - line50 - +inserted - line51 - line52 - "}, - excerpt_start_row: 10, - excerpt_row_count: 5, - expected_filtered_diff: "", - expected_line_offset: 0, - }, - Case { - name: "hunk_straddles_excerpt_start", - diff: indoc! {" - @@ -8,5 +8,6 @@ - line8 - line9 - +inserted - line10 - line11 - line12 - "}, - excerpt_start_row: 10, - excerpt_row_count: 10, - expected_filtered_diff: indoc! {" - @@ -1,3 +1,3 @@ - line10 - line11 - line12 - "}, - expected_line_offset: 1, - }, - Case { - name: "hunk_straddles_excerpt_end", - diff: indoc! {" - @@ -18,5 +18,6 @@ - line18 - line19 - +inserted - line20 - line21 - line22 - "}, - excerpt_start_row: 10, - excerpt_row_count: 10, - expected_filtered_diff: indoc! {" - @@ -8,2 +8,3 @@ - line18 - line19 - +inserted - "}, - expected_line_offset: 1, - }, - Case { - name: "multiple_hunks_mixed", - diff: indoc! 
{" - @@ -1,2 +1,3 @@ - line1 - +before_excerpt - line2 - @@ -12,2 +13,3 @@ - line12 - +inside_excerpt - line13 - @@ -50,2 +52,3 @@ - line50 - +after_excerpt - line51 - "}, - excerpt_start_row: 10, - excerpt_row_count: 10, - expected_filtered_diff: indoc! {" - @@ -3,2 +3,3 @@ - line12 - +inside_excerpt - line13 - "}, - expected_line_offset: 2, - }, - Case { - name: "deletion_before_excerpt", - diff: indoc! {" - @@ -1,4 +1,3 @@ - line1 - -deleted - line2 - line3 - "}, - excerpt_start_row: 10, - excerpt_row_count: 5, - expected_filtered_diff: "", - expected_line_offset: -1, - }, - Case { - name: "deletion_inside_excerpt", - diff: indoc! {" - @@ -12,4 +12,3 @@ - line12 - -deleted - line13 - line14 - "}, - excerpt_start_row: 10, - excerpt_row_count: 10, - expected_filtered_diff: indoc! {" - @@ -2,4 +2,3 @@ - line12 - -deleted - line13 - line14 - "}, - expected_line_offset: -1, - }, - Case { - name: "empty_diff", - diff: "", - excerpt_start_row: 10, - excerpt_row_count: 5, - expected_filtered_diff: "", - expected_line_offset: 0, - }, - Case { - name: "hunk_spans_entire_excerpt", - diff: indoc! {" - @@ -8,10 +8,12 @@ - line8 - line9 - line10 - line11 - +inserted1 - line12 - line13 - +inserted2 - line14 - line15 - line16 - line17 - "}, - excerpt_start_row: 10, - excerpt_row_count: 5, - expected_filtered_diff: indoc! {" - @@ -1,3 +1,5 @@ - line11 - +inserted1 - line12 - line13 - +inserted2 - "}, - expected_line_offset: 2, - }, - Case { - name: "replacement_inside_excerpt", - diff: indoc! {" - @@ -12,3 +12,3 @@ - line12 - -old_text - +new_text - line14 - "}, - excerpt_start_row: 10, - excerpt_row_count: 10, - expected_filtered_diff: indoc! 
{" - @@ -2,3 +2,3 @@ - line12 - -old_text - +new_text - line14 - "}, - expected_line_offset: 0, - }, - ]; - - for case in &cases { - let (filtered, line_offset) = filter_diff_hunks_by_excerpt( - case.diff, - case.excerpt_start_row, - case.excerpt_row_count, - ); - assert_eq!( - filtered, case.expected_filtered_diff, - "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}", - case.name, case.expected_filtered_diff, filtered - ); - assert_eq!( - line_offset, case.expected_line_offset, - "Test '{}': line offset mismatch. Expected {}, got {}", - case.name, case.expected_line_offset, line_offset - ); - } - } - - #[test] - fn test_excerpt_aware_reversal_tracking() { - struct Case { - name: &'static str, - edit_history_diffs: Vec<&'static str>, - excerpt_content: &'static str, - excerpt_start_row: u32, - predicted_content: &'static str, - expected_reversal_chars: usize, - expected_total_chars: usize, - } - - let cases = [ - Case { - name: "edit_outside_excerpt_no_reversal", - edit_history_diffs: vec![indoc! {" - @@ -1,2 +1,3 @@ - line1 - +added_outside - line2 - "}], - excerpt_content: indoc! {" - line10 - line11 - line12 - "}, - excerpt_start_row: 10, - predicted_content: indoc! {" - line10 - modified - line12 - "}, - expected_reversal_chars: 0, - expected_total_chars: 14, - }, - Case { - name: "edit_inside_excerpt_with_reversal", - edit_history_diffs: vec![indoc! {" - @@ -10,3 +10,4 @@ - line10 - +user_added - line11 - line12 - "}], - excerpt_content: indoc! {" - line10 - user_added - line11 - line12 - "}, - excerpt_start_row: 10, - predicted_content: indoc! {" - line10 - line11 - line12 - "}, - expected_reversal_chars: 11, - expected_total_chars: 11, - }, - Case { - name: "straddling_edit_partial_reversal", - edit_history_diffs: vec![indoc! {" - @@ -8,6 +8,8 @@ - line8 - line9 - +before_excerpt - line10 - +inside_excerpt - line11 - line12 - line13 - "}], - excerpt_content: indoc! 
{" - line10 - inside_excerpt - line11 - line12 - line13 - "}, - excerpt_start_row: 10, - predicted_content: indoc! {" - line10 - line11 - line12 - line13 - "}, - expected_reversal_chars: 15, - expected_total_chars: 15, - }, - Case { - name: "multiple_edits_mixed_locations", - edit_history_diffs: vec![ - indoc! {" - @@ -1,2 +1,3 @@ - line1 - +outside1 - line2 - "}, - indoc! {" - @@ -11,2 +12,3 @@ - line11 - +inside1 - line12 - "}, - ], - excerpt_content: indoc! {" - line10 - line11 - inside1 - line12 - line13 - "}, - excerpt_start_row: 10, - predicted_content: indoc! {" - line10 - line11 - line12 - line13 - "}, - expected_reversal_chars: 8, - expected_total_chars: 8, - }, - Case { - name: "no_edit_history", - edit_history_diffs: vec![], - excerpt_content: indoc! {" - line10 - line11 - line12 - "}, - excerpt_start_row: 10, - predicted_content: indoc! {" - line10 - modified - line12 - "}, - expected_reversal_chars: 0, - expected_total_chars: 14, - }, - Case { - name: "edit_after_excerpt_no_effect", - edit_history_diffs: vec![indoc! {" - @@ -50,2 +50,3 @@ - line50 - +added_after - line51 - "}], - excerpt_content: indoc! {" - line10 - line11 - line12 - "}, - excerpt_start_row: 10, - predicted_content: indoc! {" - line10 - changed - line12 - "}, - expected_reversal_chars: 0, - expected_total_chars: 13, - }, - Case { - name: "line_offset_tracking_across_hunks", - edit_history_diffs: vec![ - indoc! {" - @@ -1,2 +1,4 @@ - line1 - +added1 - +added2 - line2 - "}, - indoc! {" - @@ -12,2 +14,3 @@ - line12 - +inside_after_offset - line13 - "}, - ], - excerpt_content: indoc! {" - line10 - line11 - line12 - inside_after_offset - line13 - "}, - excerpt_start_row: 10, - predicted_content: indoc! 
{" - line10 - line11 - line12 - line13 - "}, - expected_reversal_chars: 20, - expected_total_chars: 20, - }, - ]; - - for case in &cases { - let overlap = compute_excerpt_aware_reversal_overlap( - &case.edit_history_diffs, - case.excerpt_content, - case.excerpt_start_row, - case.predicted_content, - ); - assert_eq!( - overlap.chars_reversing_user_edits, case.expected_reversal_chars, - "Test '{}': expected {} reversal chars, got {}", - case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits - ); - assert_eq!( - overlap.total_chars_in_prediction, case.expected_total_chars, - "Test '{}': expected {} total chars, got {}", - case.name, case.expected_total_chars, overlap.total_chars_in_prediction - ); - } - } - - #[test] - fn test_lenient_diff_application() { - struct Case { - name: &'static str, - diff: &'static str, - content: &'static str, - expected_result: &'static str, - } - - let cases = [ - Case { - name: "hunk_context_not_found_skipped", - diff: indoc! {" - @@ -1,3 +1,4 @@ - context_not_in_content - +added_line - more_context - final_context - "}, - content: indoc! {" - completely - different - content - "}, - expected_result: indoc! {" - completely - different - content - "}, - }, - Case { - name: "hunk_context_found_applied", - diff: indoc! {" - @@ -1,3 +1,4 @@ - line1 - +inserted - line2 - line3 - "}, - content: indoc! {" - line1 - line2 - line3 - "}, - expected_result: indoc! {" - line1 - inserted - line2 - line3 - "}, - }, - Case { - name: "multiple_hunks_partial_match", - diff: indoc! {" - @@ -1,2 +1,3 @@ - not_found - +skipped - also_not_found - @@ -5,2 +6,3 @@ - line5 - +applied - line6 - "}, - content: indoc! {" - line1 - line2 - line3 - line4 - line5 - line6 - "}, - expected_result: indoc! {" - line1 - line2 - line3 - line4 - line5 - applied - line6 - "}, - }, - Case { - name: "empty_diff", - diff: "", - content: indoc! {" - unchanged - content - "}, - expected_result: indoc! 
{" - unchanged - content - "}, - }, - ]; - - for case in &cases { - let result = apply_diff_to_string_lenient(case.diff, case.content); - assert_eq!( - result, case.expected_result, - "Test '{}': expected:\n{}\ngot:\n{}", - case.name, case.expected_result, result - ); - } - } - - #[test] - fn test_unicode_reversal_overlap() { - struct Case { - name: &'static str, - original: &'static str, - current: &'static str, - predicted: &'static str, - expected_reversal_chars: usize, - expected_total_chars: usize, - } - - let cases = [ - Case { - name: "unicode_extension_cjk", - original: "", - current: "日", // 1 char - predicted: "日本語", // 3 chars, adds 2 chars - expected_reversal_chars: 0, - expected_total_chars: 2, // "本語" = 2 chars added - }, - Case { - name: "unicode_extension_emoji", - original: "", - current: "🎉", // 1 char - predicted: "🎉🎊🎈", // 3 chars, adds 2 chars - expected_reversal_chars: 0, - expected_total_chars: 2, // "🎊🎈" = 2 chars added - }, - Case { - name: "unicode_deletion_restored", - original: "héllo wörld", // 11 chars - current: "héllo", // 5 chars - predicted: "héllo wörld", // restores " wörld" = 6 chars - expected_reversal_chars: 6, // LCS(" wörld", " wörld") = 6 chars - expected_total_chars: 6, - }, - Case { - name: "unicode_addition_reversed", - original: "café", // 4 chars - current: "café latté", // 10 chars, added " latté" = 6 chars - predicted: "café", // removes " latté" - expected_reversal_chars: 6, // 6 chars removed - expected_total_chars: 6, - }, - Case { - name: "mixed_ascii_unicode", - original: "", - current: "test日本", // 6 chars - predicted: "test日本語です", // 9 chars - expected_reversal_chars: 0, - expected_total_chars: 3, // 3 new chars after subsequence normalization - }, - Case { - name: "unicode_replacement_not_subsequence", - original: "", - current: "日本", // 2 chars - predicted: "中国", // 2 chars, different - expected_reversal_chars: 2, // removes "日本" = 2 chars - expected_total_chars: 4, // 2 removed + 2 added - }, - ]; - - for 
case in &cases { - let overlap = compute_reversal_overlap(case.original, case.current, case.predicted); - assert_eq!( - overlap.chars_reversing_user_edits, case.expected_reversal_chars, - "Test '{}': expected {} reversal chars, got {}", - case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits - ); - assert_eq!( - overlap.total_chars_in_prediction, case.expected_total_chars, - "Test '{}': expected {} total chars, got {}", - case.name, case.expected_total_chars, overlap.total_chars_in_prediction - ); - } - } - - #[test] - fn test_compute_lcs_length() { - assert_eq!(compute_lcs_length("", ""), 0); - assert_eq!(compute_lcs_length("abc", ""), 0); - assert_eq!(compute_lcs_length("", "abc"), 0); - assert_eq!(compute_lcs_length("abc", "abc"), 3); - assert_eq!(compute_lcs_length("abc", "def"), 0); - assert_eq!(compute_lcs_length("abcdef", "ace"), 3); - assert_eq!(compute_lcs_length("AGGTAB", "GXTXAYB"), 4); - assert_eq!(compute_lcs_length("日本語", "日語"), 2); - } - - #[test] - fn test_compute_prediction_reversal_ratio_full_file() { - let prompt_inputs = make_test_prompt_inputs( - indoc! {" - line1 - user_added - line2 - "}, - vec![Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("src/test.rs")), - old_path: Arc::from(Path::new("src/test.rs")), - diff: indoc! {" - @@ -1,2 +1,3 @@ - line1 - +user_added - line2 - "} - .into(), - predicted: false, - in_open_source_repo: false, - })], - None, - ); - - let predicted = indoc! {" - line1 - line2 - "}; - let ratio = - compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); - - assert!( - ratio > 0.9, - "Expected high reversal ratio when prediction removes user addition, got {}", - ratio - ); - } - - #[test] - fn test_compute_prediction_reversal_ratio_with_excerpt() { - let prompt_inputs = make_test_prompt_inputs( - indoc! 
{" - line10 - user_added - line11 - "}, - vec![Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("src/test.rs")), - old_path: Arc::from(Path::new("src/test.rs")), - diff: indoc! {" - @@ -10,2 +10,3 @@ - line10 - +user_added - line11 - "} - .into(), - predicted: false, - in_open_source_repo: false, - })], - Some(10), - ); - - let predicted = indoc! {" - line10 - line11 - "}; - let ratio = - compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); - - assert!( - ratio > 0.9, - "Expected high reversal ratio for excerpt-aware computation, got {}", - ratio - ); - } - - #[test] - fn test_compute_prediction_reversal_ratio_no_history() { - let prompt_inputs = make_test_prompt_inputs( - indoc! {" - original content - "}, - vec![], - None, - ); - - let predicted = indoc! {" - completely different - "}; - let ratio = - compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); - - assert_eq!( - ratio, 0.0, - "Expected zero reversal ratio with no edit history" - ); - } - - #[test] - fn test_compute_prediction_reversal_ratio_path_filtering() { - let prompt_inputs = make_test_prompt_inputs( - indoc! {" - line1 - user_added - line2 - "}, - vec![Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("src/other.rs")), - old_path: Arc::from(Path::new("src/other.rs")), - diff: indoc! {" - @@ -1,2 +1,3 @@ - line1 - +user_added - line2 - "} - .into(), - predicted: false, - in_open_source_repo: false, - })], - None, - ); - - let predicted = indoc! {" - line1 - line2 - "}; - let ratio = - compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); - - assert_eq!( - ratio, 0.0, - "Expected zero reversal when edit history is for different file" - ); - } - - #[test] - fn test_compute_prediction_reversal_ratio_lenient_fallback() { - let prompt_inputs = make_test_prompt_inputs( - indoc! 
{" - actual_line1 - user_added - actual_line2 - "}, - vec![Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("src/test.rs")), - old_path: Arc::from(Path::new("src/test.rs")), - diff: indoc! {" - @@ -1,2 +1,3 @@ - wrong_context - +user_added - more_wrong - "} - .into(), - predicted: false, - in_open_source_repo: false, - })], - None, - ); - - let predicted = indoc! {" - actual_line1 - actual_line2 - "}; - let ratio = - compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); - - assert!( - ratio >= 0.0 && ratio <= 1.0, - "Ratio should be valid even with lenient fallback, got {}", - ratio - ); - } - - #[test] - fn test_excerpt_aware_reversal_error_recovery() { - let diffs = vec![indoc! {" - @@ -1,2 +1,3 @@ - nonexistent_context - +added - more_nonexistent - "}]; - let excerpt_content = indoc! {" - completely - different - content - "}; - let predicted_content = indoc! {" - completely - modified - content - "}; - - let overlap = - compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content); - - assert!( - overlap.ratio() >= 0.0 && overlap.ratio() <= 1.0, - "Should handle failed diff application gracefully" - ); - } - - #[test] - fn test_only_most_recent_edit_tracked() { - let prompt_inputs = make_test_prompt_inputs( - indoc! {" - line1 - first_add - second_add - line2 - "}, - vec![ - Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("src/test.rs")), - old_path: Arc::from(Path::new("src/test.rs")), - diff: indoc! {" - @@ -1,2 +1,3 @@ - line1 - +first_add - line2 - "} - .into(), - predicted: false, - in_open_source_repo: false, - }), - Arc::new(zeta_prompt::Event::BufferChange { - path: Arc::from(Path::new("src/test.rs")), - old_path: Arc::from(Path::new("src/test.rs")), - diff: indoc! {" - @@ -2,2 +2,3 @@ - first_add - +second_add - line2 - "} - .into(), - predicted: false, - in_open_source_repo: false, - }), - ], - None, - ); - - let predicted = indoc! 
{" - line1 - first_add - line2 - "}; - let ratio = - compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); - - assert!( - ratio > 0.9, - "Expected high reversal ratio when prediction exactly reverses the most recent edit, got {}", - ratio - ); - } + edit_prediction_metrics::compute_prediction_reversal_ratio_from_history( + prompt_inputs.cursor_excerpt.as_ref(), + &prompt_inputs.events, + prompt_inputs.excerpt_start_row, + predicted_content, + cursor_path, + ) } diff --git a/crates/edit_prediction_metrics/Cargo.toml b/crates/edit_prediction_metrics/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..02181ca3d2456df102e5f4f93e7e3b0e0a5d8313 --- /dev/null +++ b/crates/edit_prediction_metrics/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "edit_prediction_metrics" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/edit_prediction_metrics.rs" + +[dependencies] +language.workspace = true +serde.workspace = true +similar = "2.7.0" +tree-sitter.workspace = true +zeta_prompt.workspace = true + +[dev-dependencies] +indoc.workspace = true +pretty_assertions.workspace = true diff --git a/crates/edit_prediction_metrics/LICENSE-GPL b/crates/edit_prediction_metrics/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/edit_prediction_metrics/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs b/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs new file mode 100644 index 0000000000000000000000000000000000000000..4fbaaf71331c285009091c9bd7b16eafdc6d2829 --- /dev/null +++ b/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs @@ -0,0 +1,24 @@ +mod kept_rate; +mod patch_metrics; +mod reversal; +mod tokenize; +mod 
tree_sitter; + +pub use kept_rate::KeptRateResult; +#[cfg(test)] +pub use kept_rate::TokenAnnotation; +pub use kept_rate::compute_kept_rate; +pub use patch_metrics::ClassificationMetrics; +pub use patch_metrics::Counts; +pub use patch_metrics::DeltaChrFMetrics; +pub use patch_metrics::TokenChangeCounts; +pub use patch_metrics::braces_disbalance; +pub use patch_metrics::count_patch_token_changes; +pub use patch_metrics::delta_chr_f; +pub use patch_metrics::delta_chr_f_beta; +pub use patch_metrics::exact_lines_match; +pub use patch_metrics::extract_changed_lines_from_diff; +pub use patch_metrics::has_isolated_whitespace_changes; +pub use patch_metrics::is_editable_region_correct; +pub use reversal::compute_prediction_reversal_ratio_from_history; +pub use tree_sitter::count_tree_sitter_errors; diff --git a/crates/edit_prediction/src/metrics/kept_rate.rs b/crates/edit_prediction_metrics/src/kept_rate.rs similarity index 99% rename from crates/edit_prediction/src/metrics/kept_rate.rs rename to crates/edit_prediction_metrics/src/kept_rate.rs index 599280f5d9aea7964b9d99ab318356e9f4acfb49..117ab743c2b0ef51e0f31bc97d1b38af3b534a47 100644 --- a/crates/edit_prediction/src/metrics/kept_rate.rs +++ b/crates/edit_prediction_metrics/src/kept_rate.rs @@ -1,9 +1,10 @@ -use crate::metrics::tokenize; +use crate::tokenize::tokenize; +use serde::Serialize; const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512; #[cfg(test)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] pub enum TokenAnnotation { Context, Kept, @@ -11,7 +12,7 @@ pub enum TokenAnnotation { } #[allow(dead_code)] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct KeptRateResult { /// Characters newly introduced by the candidate pub candidate_new_chars: usize, diff --git a/crates/edit_prediction_metrics/src/patch_metrics.rs b/crates/edit_prediction_metrics/src/patch_metrics.rs new file mode 100644 index 
0000000000000000000000000000000000000000..9da499796efabc8e1a767dd1b2ed3843b38d06eb --- /dev/null +++ b/crates/edit_prediction_metrics/src/patch_metrics.rs @@ -0,0 +1,1451 @@ +use std::collections::HashMap; + +use crate::tokenize::tokenize; +use serde::Serialize; +use similar::{DiffTag, TextDiff}; + +pub type Counts = HashMap; +type CountsDelta = HashMap; + +/// Context characters needed on each side of a change to capture all affected n-grams +const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1; + +#[derive(Default, Debug, Clone, Serialize)] +pub struct ClassificationMetrics { + pub true_positives: usize, + pub false_positives: usize, + pub false_negatives: usize, +} + +impl ClassificationMetrics { + pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { + let mut true_positives = 0; + let mut false_positives = 0; + let mut false_negatives = 0; + + for (ngram, &expected_count) in expected { + let actual_count = *actual.get(ngram).unwrap_or(&0); + if actual_count > expected_count { + false_positives += actual_count - expected_count; + } else { + false_negatives += expected_count - actual_count; + } + true_positives += expected_count.min(actual_count); + } + + for (ngram, &actual_count) in actual { + if !expected.contains_key(ngram) { + false_positives += actual_count; + } + } + + ClassificationMetrics { + true_positives, + false_positives, + false_negatives, + } + } + + pub fn accumulate(&mut self, other: &ClassificationMetrics) { + self.true_positives += other.true_positives; + self.false_positives += other.false_positives; + self.false_negatives += other.false_negatives; + } + + pub fn precision(&self) -> f64 { + if self.true_positives + self.false_positives == 0 { + 0.0 + } else { + self.true_positives as f64 / (self.true_positives + self.false_positives) as f64 + } + } + + pub fn recall(&self) -> f64 { + if self.true_positives + self.false_negatives == 0 { + 0.0 + } else { + self.true_positives as f64 / (self.true_positives + 
self.false_negatives) as f64 + } + } + + pub fn f1(&self) -> f64 { + let precision = self.precision(); + let recall = self.recall(); + if precision + recall == 0.0 { + 0.0 + } else { + 2.0 * precision * recall / (precision + recall) + } + } +} + +enum ChrfWhitespace { + /// Preserve whitespace as-is + #[allow(unused)] + Unchanged, + + /// Ignore all whitespace differences + #[allow(unused)] + Ignore, + + /// Collapse whitespace into single spaces + Collapse, +} + +const CHR_F_CHAR_ORDER: usize = 6; +const CHR_F_BETA: f64 = 0.5; +const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse; + +pub fn delta_chr_f_beta() -> f64 { + CHR_F_BETA +} + +#[derive(Default, Debug, Clone, Serialize)] +pub struct DeltaChrFMetrics { + pub score: f64, + pub beta: f64, + pub counts: ClassificationMetrics, + pub precision: f64, + pub recall: f64, +} + +/// Computes delta-chrF metrics that compare two sets of edits. +/// +/// This metric works by: +/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual +/// 2. Comparing these deltas to measure how well actual edits match expected edits +/// +/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match +/// the expected edits. 
+pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics { + if original == expected && expected == actual { + return DeltaChrFMetrics { + score: 100.0, + beta: CHR_F_BETA, + precision: 1.0, + recall: 1.0, + ..DeltaChrFMetrics::default() + }; + } + + let orig_chars: Vec = filter_whitespace_chars(original); + let exp_chars: Vec = filter_whitespace_chars(expected); + let act_chars: Vec = filter_whitespace_chars(actual); + + // Find the changed regions between original→expected and original→actual + // We only need to compute n-grams on these regions (plus context for boundary n-grams) + let (orig_for_exp, exp_region) = extract_changed_regions(&orig_chars, &exp_chars); + let (orig_for_act, act_region) = extract_changed_regions(&orig_chars, &act_chars); + + let mut total_precision = 0.0; + let mut total_recall = 0.0; + let mut total_counts = ClassificationMetrics::default(); + + for order in 1..=CHR_F_CHAR_ORDER { + let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order); + let exp_ngrams = count_ngrams_from_chars(&exp_region, order); + let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp); + + let orig_ngrams_for_act = count_ngrams_from_chars(&orig_for_act, order); + let act_ngrams = count_ngrams_from_chars(&act_region, order); + let actual_delta = compute_ngram_delta(&act_ngrams, &orig_ngrams_for_act); + + if expected_delta.is_empty() && actual_delta.is_empty() { + total_precision += 1.0; + total_recall += 1.0; + continue; + } + + let expected_counts = ngram_delta_to_counts(&expected_delta); + let actual_counts = ngram_delta_to_counts(&actual_delta); + + let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); + total_precision += counts.precision(); + total_recall += counts.recall(); + total_counts.accumulate(&counts); + } + + let average_precision = total_precision / CHR_F_CHAR_ORDER as f64; + let average_recall = total_recall / CHR_F_CHAR_ORDER as f64; + let score = if 
average_precision + average_recall == 0.0 { + 0.0 + } else { + (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall + / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall) + * 100.0 + }; + + DeltaChrFMetrics { + score, + beta: CHR_F_BETA, + counts: total_counts, + precision: average_precision, + recall: average_recall, + } +} + +/// Reference implementation of delta-chrF metrics (original, non-optimized version). +/// Used for testing that the optimized version produces identical results. +#[cfg(test)] +fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics { + if original == expected && expected == actual { + return DeltaChrFMetrics { + score: 100.0, + beta: CHR_F_BETA, + precision: 1.0, + recall: 1.0, + ..DeltaChrFMetrics::default() + }; + } + + let original_ngrams = chr_f_ngram_counts(original); + let expected_ngrams = chr_f_ngram_counts(expected); + let actual_ngrams = chr_f_ngram_counts(actual); + + let mut total_precision = 0.0; + let mut total_recall = 0.0; + let mut total_counts = ClassificationMetrics::default(); + + for order in 0..CHR_F_CHAR_ORDER { + let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]); + let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]); + + if expected_delta.is_empty() && actual_delta.is_empty() { + total_precision += 1.0; + total_recall += 1.0; + continue; + } + + let expected_counts = ngram_delta_to_counts(&expected_delta); + let actual_counts = ngram_delta_to_counts(&actual_delta); + + let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); + total_precision += counts.precision(); + total_recall += counts.recall(); + total_counts.accumulate(&counts); + } + + let average_precision = total_precision / CHR_F_CHAR_ORDER as f64; + let average_recall = total_recall / CHR_F_CHAR_ORDER as f64; + let score = if average_precision + average_recall == 0.0 { + 0.0 + } else { + (1.0 
+ CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall + / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall) + * 100.0 + }; + + DeltaChrFMetrics { + score, + beta: CHR_F_BETA, + counts: total_counts, + precision: average_precision, + recall: average_recall, + } +} + +/// Filter whitespace from a string and return as Vec +fn filter_whitespace_chars(text: &str) -> Vec { + match CHR_F_WHITESPACE { + ChrfWhitespace::Unchanged => text.chars().collect(), + ChrfWhitespace::Ignore => text.chars().filter(|c| !c.is_whitespace()).collect(), + ChrfWhitespace::Collapse => collapse_whitespace(text.chars()), + } +} + +/// Collapse whitespace into single spaces. +/// Newlines and spaces are collapsed separately. +fn collapse_whitespace(chars: impl Iterator) -> Vec { + let mut result = Vec::new(); + let mut last_whitespace = None; + for c in chars { + if c.is_whitespace() && c != '\n' { + if last_whitespace != Some(' ') { + result.push(' '); + last_whitespace = Some(' '); + } + } else if c == '\n' { + if last_whitespace != Some('\n') { + result.push(c); + last_whitespace = Some('\n'); + } + } else { + result.push(c); + last_whitespace = None; + } + } + result +} + +/// Extract only the changed regions between two texts, with context for n-gram boundaries. +/// +/// Returns (original_affected_region, modified_affected_region) as Vec. +/// +/// The key insight: when computing n-gram delta between two nearly-identical texts, +/// n-grams from unchanged regions cancel out. We only need to process: +/// 1. The changed content itself +/// 2. 
CONTEXT_CHARS (n-1) characters before and after, to capture boundary-crossing n-grams +fn extract_changed_regions(original: &[char], modified: &[char]) -> (Vec, Vec) { + // Find longest common prefix + let prefix_len = original + .iter() + .zip(modified.iter()) + .take_while(|(a, b)| a == b) + .count(); + + // Find longest common suffix (that doesn't overlap with prefix) + let orig_remaining = original.len().saturating_sub(prefix_len); + let mod_remaining = modified.len().saturating_sub(prefix_len); + let max_suffix = orig_remaining.min(mod_remaining); + + let suffix_len = original + .iter() + .rev() + .zip(modified.iter().rev()) + .take(max_suffix) + .take_while(|(a, b)| a == b) + .count(); + + // Calculate the changed region boundaries + let orig_change_start = prefix_len; + let orig_change_end = original.len().saturating_sub(suffix_len); + let mod_change_start = prefix_len; + let mod_change_end = modified.len().saturating_sub(suffix_len); + + // If there's no actual change, return empty regions + if orig_change_start >= orig_change_end && mod_change_start >= mod_change_end { + return (Vec::new(), Vec::new()); + } + + // Expand to include context for n-gram boundaries + let orig_context_start = orig_change_start.saturating_sub(CONTEXT_CHARS); + let orig_context_end = (orig_change_end + CONTEXT_CHARS).min(original.len()); + let mod_context_start = mod_change_start.saturating_sub(CONTEXT_CHARS); + let mod_context_end = (mod_change_end + CONTEXT_CHARS).min(modified.len()); + + let orig_region: Vec = original[orig_context_start..orig_context_end].to_vec(); + let mod_region: Vec = modified[mod_context_start..mod_context_end].to_vec(); + + (orig_region, mod_region) +} + +/// Count n-grams directly from a char slice (avoids String allocation for the full text) +fn count_ngrams_from_chars(chars: &[char], n: usize) -> Counts { + let mut counts = Counts::default(); + + if chars.len() < n { + return counts; + } + + for window in chars.windows(n) { + let ngram: String = 
window.iter().collect(); + *counts.entry(ngram).or_insert(0) += 1; + } + + counts +} + +#[allow(dead_code)] +fn chr_f_ngram_counts(text: &str) -> Vec { + let text = match CHR_F_WHITESPACE { + ChrfWhitespace::Unchanged => text.to_string(), + ChrfWhitespace::Ignore => text + .chars() + .filter(|c| !c.is_whitespace()) + .collect::(), + ChrfWhitespace::Collapse => collapse_whitespace(text.chars()) + .into_iter() + .collect::(), + }; + + (1..=CHR_F_CHAR_ORDER) + .map(|order| count_ngrams(&text, order)) + .collect() +} + +fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta { + let mut delta = CountsDelta::default(); + + for (ngram, &before_count) in before { + let after_count = *after.get(ngram).unwrap_or(&0); + delta.insert(ngram.clone(), after_count as isize - before_count as isize); + } + + for (ngram, &after_count) in after { + if !before.contains_key(ngram) { + delta.insert(ngram.clone(), after_count as isize); + } + } + + delta +} + +/// Convert negative counts to special deletion tokens. +/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1}, +/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo" +/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive. 
+fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts { + let mut counts = Counts::default(); + + for (ngram, &delta) in delta { + if delta > 0 { + counts.insert(ngram.clone(), delta as usize); + } else if delta < 0 { + counts.insert(format!("¬{ngram}"), delta.unsigned_abs()); + } + } + + counts +} + +#[allow(dead_code)] +fn count_ngrams(text: &str, n: usize) -> Counts { + let chars: Vec = text.chars().collect(); + let mut counts = Counts::default(); + + for window in chars.windows(n) { + let ngram: String = window.iter().collect(); + *counts.entry(ngram).or_insert(0) += 1; + } + + counts +} + +pub fn braces_disbalance(text: &str) -> usize { + let mut disbalance = 0isize; + + let a = text.chars().filter(|&c| c == '{').count() as isize; + let b = text.chars().filter(|&c| c == '}').count() as isize; + disbalance += (a - b).abs(); + + let a = text.chars().filter(|&c| c == '(').count() as isize; + let b = text.chars().filter(|&c| c == ')').count() as isize; + disbalance += (a - b).abs(); + + let a = text.chars().filter(|&c| c == '[').count() as isize; + let b = text.chars().filter(|&c| c == ']').count() as isize; + disbalance += (a - b).abs(); + + disbalance as usize +} + +/// Extracts changed lines from a unified diff string. +/// Returns a bag (multiset) of lines that were added (+) or removed (-). +/// The +/- prefix is included in the line to distinguish additions from deletions. +pub fn extract_changed_lines_from_diff(diff: &str) -> Counts { + let mut counts = Counts::default(); + + for line in diff.lines() { + // Skip file headers (--- and +++) + if line.starts_with("---") || line.starts_with("+++") { + continue; + } + // Skip hunk headers (@@) + if line.starts_with("@@") { + continue; + } + // Skip diff header lines (diff --git, index, etc.) 
+ if line.starts_with("diff ") || line.starts_with("index ") { + continue; + } + // Include added and removed lines (with their prefix) + if line.starts_with('+') || line.starts_with('-') { + *counts.entry(line.to_string()).or_insert(0) += 1; + } + } + + counts +} + +/// Computes exact lines match metrics between expected and actual patches. +/// Treats changed lines as a bag (multiset) - order is discarded but count matters. +/// Returns ClassificationMetrics with TP/FP/FN counts. +pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics { + let expected_lines = extract_changed_lines_from_diff(expected_patch); + let actual_lines = extract_changed_lines_from_diff(actual_patch); + ClassificationMetrics::from_counts(&expected_lines, &actual_lines) +} + +/// Returns whether the patch contains any isolated whitespace-only changes. +/// +/// A whitespace-only change is an added or deleted line whose content is empty or +/// contains only whitespace. It is "isolated" when it is not adjacent to any +/// substantive (non-whitespace) change within the same contiguous change group. 
+pub fn has_isolated_whitespace_changes(patch_str: &str, cursor_row: Option<u32>) -> bool {
+    let patch = Patch::parse_unified_diff(patch_str);
+
+    let cursor_new_file_line = cursor_row.map(|row| (row + 1) as usize);
+
+    for hunk in &patch.hunks {
+        let lines = &hunk.lines;
+        let mut new_text_line = hunk.new_start as usize;
+
+        for (i, line) in lines.iter().enumerate() {
+            let content = match line {
+                PatchLine::Addition(s) => {
+                    let addition_line = new_text_line;
+                    new_text_line += 1;
+                    if s.trim().is_empty() && cursor_new_file_line == Some(addition_line) {
+                        continue;
+                    }
+                    s.as_str()
+                }
+                PatchLine::Deletion(s) => s.as_str(),
+                PatchLine::Context(_) => {
+                    new_text_line += 1;
+                    continue;
+                }
+                _ => continue,
+            };
+
+            if !content.trim().is_empty() {
+                continue;
+            }
+
+            if is_whitespace_change_isolated(lines, i) {
+                return true;
+            }
+        }
+    }
+
+    false
+}
+
+fn is_whitespace_change_isolated(lines: &[PatchLine], index: usize) -> bool {
+    // Look backward for a non-whitespace change before hitting a context line
+    for line in lines[..index].iter().rev() {
+        match line {
+            PatchLine::Addition(s) | PatchLine::Deletion(s) => {
+                if !s.trim().is_empty() {
+                    return false;
+                }
+            }
+            _ => break,
+        }
+    }
+
+    // Look forward for a non-whitespace change before hitting a context line
+    for line in &lines[index + 1..] {
+        match line {
+            PatchLine::Addition(s) | PatchLine::Deletion(s) => {
+                if !s.trim().is_empty() {
+                    return false;
+                }
+            }
+            _ => break,
+        }
+    }
+
+    true
+}
+
+/// A simple proxy for whether the prediction respects editable region.
+pub fn is_editable_region_correct(actual_patch: &str) -> bool {
+    // A typical sign of a wrong editable region: a bunch of lines deletion
+    // at the beginning or end of the patch.
+ let patch = Patch::parse_unified_diff(actual_patch); + if patch.hunks.is_empty() { + return true; + } + + let hunk = &patch.hunks[0]; + let mut deletions_at_start = 0; + + for line in hunk.lines.iter() { + match line { + PatchLine::Deletion(_) => deletions_at_start += 1, + _ => break, + } + } + + if deletions_at_start >= 3 { + return false; + } + + true +} + +#[derive(Debug, Default, Clone, Serialize)] +pub struct TokenChangeCounts { + pub inserted_tokens: usize, + pub deleted_tokens: usize, +} + +/// Counts the number of inserted and deleted tokens in a unified diff patch. +/// +/// Tokens are words and whitespace sequences (as defined by `word_diff::tokenize`). +/// Within each hunk, the old (`-`) and new (`+`) lines are compared at the token level +/// using an LCS-based diff, so modified lines only count the actually changed tokens +/// rather than the entire line. +pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts { + let mut counts = TokenChangeCounts::default(); + let mut old_lines: Vec<&str> = Vec::new(); + let mut new_lines: Vec<&str> = Vec::new(); + + let flush = + |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, counts: &mut TokenChangeCounts| { + if old_lines.is_empty() && new_lines.is_empty() { + return; + } + + let old_text: String = old_lines + .iter() + .map(|line| if line.len() > 1 { &line[1..] } else { "" }) + .collect::>() + .join("\n"); + + let new_text: String = new_lines + .iter() + .map(|line| if line.len() > 1 { &line[1..] } else { "" }) + .collect::>() + .join("\n"); + + let old_tokens = tokenize(&old_text); + let new_tokens = tokenize(&new_text); + let ops = diff_tokens(&old_tokens, &new_tokens); + + for op in ops { + match op { + DiffOp::Equal(..) 
=> {}
+                    DiffOp::Delete(start, end) => {
+                        counts.deleted_tokens += end - start;
+                    }
+                    DiffOp::Insert(start, end) => {
+                        counts.inserted_tokens += end - start;
+                    }
+                    DiffOp::Replace {
+                        old_start,
+                        old_end,
+                        new_start,
+                        new_end,
+                    } => {
+                        counts.deleted_tokens += old_end - old_start;
+                        counts.inserted_tokens += new_end - new_start;
+                    }
+                }
+            }
+
+            old_lines.clear();
+            new_lines.clear();
+        };
+
+    for line in patch.lines() {
+        if line.starts_with("---")
+            || line.starts_with("+++")
+            || line.starts_with("@@")
+            || line.starts_with("diff ")
+            || line.starts_with("index ")
+        {
+            flush(&mut old_lines, &mut new_lines, &mut counts);
+        } else if line.starts_with('-') {
+            old_lines.push(line);
+        } else if line.starts_with('+') {
+            new_lines.push(line);
+        } else {
+            flush(&mut old_lines, &mut new_lines, &mut counts);
+        }
+    }
+
+    flush(&mut old_lines, &mut new_lines, &mut counts);
+    counts
+}
+
+#[allow(dead_code)]
+#[derive(Debug)]
+enum DiffOp {
+    Equal(usize, usize),
+    Delete(usize, usize),
+    Insert(usize, usize),
+    Replace {
+        old_start: usize,
+        old_end: usize,
+        new_start: usize,
+        new_end: usize,
+    },
+}
+
+fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
+    let diff = TextDiff::from_slices(old, new);
+    diff.ops()
+        .iter()
+        .map(|op| {
+            let tag = op.tag();
+            let old_range = op.old_range();
+            let new_range = op.new_range();
+            match tag {
+                DiffTag::Equal => DiffOp::Equal(old_range.start, old_range.end),
+                DiffTag::Delete => DiffOp::Delete(old_range.start, old_range.end),
+                DiffTag::Insert => DiffOp::Insert(new_range.start, new_range.end),
+                DiffTag::Replace => DiffOp::Replace {
+                    old_start: old_range.start,
+                    old_end: old_range.end,
+                    new_start: new_range.start,
+                    new_end: new_range.end,
+                },
+            }
+        })
+        .collect()
+}
+
+#[derive(Debug, Default, Clone)]
+struct Patch {
+    hunks: Vec<Hunk>,
+}
+
+impl Patch {
+    fn parse_unified_diff(unified_diff: &str) -> Patch {
+        let mut current_file = String::new();
+        let mut is_filename_inherited = false;
+        let mut hunk =
Hunk::default();
+        let mut patch = Patch::default();
+        let mut in_header = true;
+
+        for line in unified_diff.lines() {
+            if line.starts_with("--- ") || line.starts_with("+++") || line.starts_with("@@") {
+                in_header = false;
+            }
+
+            if in_header {
+                continue;
+            }
+
+            if line.starts_with("@@") {
+                if !hunk.lines.is_empty() {
+                    patch.hunks.push(hunk);
+                }
+                hunk = Hunk::from_header(line, &current_file, is_filename_inherited);
+                is_filename_inherited = true;
+            } else if let Some(path) = line.strip_prefix("--- ") {
+                is_filename_inherited = false;
+                let path = path.trim().strip_prefix("a/").unwrap_or(path);
+                if path != "/dev/null" {
+                    current_file = path.into();
+                }
+            } else if let Some(path) = line.strip_prefix("+++ ") {
+                is_filename_inherited = false;
+                let path = path.trim().strip_prefix("b/").unwrap_or(path);
+                if path != "/dev/null" {
+                    current_file = path.into();
+                }
+            } else if let Some(line) = line.strip_prefix('+') {
+                hunk.lines.push(PatchLine::Addition(line.to_string()));
+            } else if let Some(line) = line.strip_prefix('-') {
+                hunk.lines.push(PatchLine::Deletion(line.to_string()));
+            } else if let Some(line) = line.strip_prefix(' ') {
+                hunk.lines.push(PatchLine::Context(line.to_string()));
+            } else {
+                hunk.lines.push(PatchLine::Garbage(line.to_string()));
+            }
+        }
+
+        if !hunk.lines.is_empty() {
+            patch.hunks.push(hunk);
+        }
+
+        patch
+    }
+}
+
+#[derive(Debug, Default, Clone)]
+struct Hunk {
+    new_start: isize,
+    lines: Vec<PatchLine>,
+}
+
+impl Hunk {
+    fn from_header(header: &str, _filename: &str, _is_filename_inherited: bool) -> Self {
+        let (_, _, new_start, _, _) = Self::parse_hunk_header(header);
+        Self {
+            new_start,
+            lines: Vec::new(),
+        }
+    }
+
+    fn parse_hunk_header(line: &str) -> (isize, isize, isize, isize, String) {
+        let header_part = line.trim_start_matches("@@").trim();
+        let parts: Vec<&str> = header_part.split_whitespace().collect();
+
+        if parts.len() < 2 {
+            return (0, 0, 0, 0, String::new());
+        }
+
+        let old_part = parts[0].trim_start_matches('-');
+ let new_part = parts[1].trim_start_matches('+'); + + let (old_start, old_count) = Hunk::parse_hunk_header_range(old_part); + let (new_start, new_count) = Hunk::parse_hunk_header_range(new_part); + + let comment = if parts.len() > 2 { + parts[2..] + .join(" ") + .trim_start_matches("@@") + .trim() + .to_string() + } else { + String::new() + }; + + ( + old_start as isize, + old_count as isize, + new_start as isize, + new_count as isize, + comment, + ) + } + + fn parse_hunk_header_range(part: &str) -> (usize, usize) { + if let Some((start, count)) = part.split_once(',') { + (start.parse().unwrap_or(0), count.parse().unwrap_or(0)) + } else { + (part.parse().unwrap_or(0), 1) + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +enum PatchLine { + Context(String), + Addition(String), + Deletion(String), + Garbage(String), +} + +#[cfg(test)] +mod test_optimization { + use super::*; + + #[test] + fn test_extract_changed_regions_simple() { + let original: Vec = "hello world".chars().collect(); + let modified: Vec = "hello there".chars().collect(); + + let (orig_region, mod_region) = extract_changed_regions(&original, &modified); + + // "world" vs "there" - with 5 chars context, we get "ello world" vs "ello there" + // (or less if not enough chars available) + assert!(orig_region.len() < original.len()); + assert!(mod_region.len() < modified.len()); + } + + #[test] + fn test_extract_changed_regions_insertion() { + let original: Vec = "abcdef".chars().collect(); + let modified: Vec = "abcXYZdef".chars().collect(); + + let (orig_region, mod_region) = extract_changed_regions(&original, &modified); + + // The insertion is between c and d, so we need context around that point + assert!(orig_region.len() <= original.len()); + assert!(mod_region.iter().collect::().contains("XYZ")); + } + + #[test] + fn test_extract_changed_regions_identical() { + let text: Vec = "identical text".chars().collect(); + + let (orig_region, mod_region) = extract_changed_regions(&text, &text); + + // 
When texts are identical, regions should be empty + assert!(orig_region.is_empty()); + assert!(mod_region.is_empty()); + } + + #[test] + fn test_optimized_matches_original_score() { + // Test that our optimized version produces the same results + let test_cases = vec![ + ("hello world", "hello there", "hello world"), + ( + "fn main() {}", + "fn main() { println!(); }", + "fn main() { print!(); }", + ), + ("abcdefghij", "abcXXXghij", "abcYYghij"), + ("unchanged", "unchanged", "unchanged"), + ( + "prefix middle suffix", + "prefix CHANGED suffix", + "prefix middle suffix", + ), + ]; + + for (original, expected, actual) in test_cases { + let score = delta_chr_f(original, expected, actual).score; + // Just verify it produces a reasonable score (0-100) + assert!( + score >= 0.0 && score <= 100.0, + "Score {} out of range for ({}, {}, {})", + score, + original, + expected, + actual + ); + } + } + + #[test] + fn test_optimized_equals_reference() { + // Comprehensive test that optimized version matches reference implementation exactly + let test_cases = vec![ + // Basic cases + ("hello world", "hello there", "hello world"), + ("hello world", "hello there", "hello there"), + ("unchanged", "unchanged", "unchanged"), + // Code-like cases + ( + "fn main() { println!(\"Hello\"); }", + "fn main() { println!(\"Hello, World!\"); }", + "fn main() { println!(\"Hello, World!\"); }", + ), + ( + "fn main() { println!(\"Hello\"); }", + "fn main() { println!(\"Hello, World!\"); }", + "fn main() { println!(\"Goodbye\"); }", + ), + // Insertion + ("abcdef", "abcXYZdef", "abcdef"), + ("abcdef", "abcXYZdef", "abcXYZdef"), + ("abcdef", "abcXYZdef", "abcABCdef"), + // Deletion + ("abcXYZdef", "abcdef", "abcXYZdef"), + ("abcXYZdef", "abcdef", "abcdef"), + // Multiple changes (simulated by different expected/actual) + ("one two three four", "one THREE four", "one two FOUR"), + // Edge cases + ("a", "b", "c"), + ("", "abc", ""), + ("abc", "", "abc"), + // Longer text with small change + ( + "This 
is a longer piece of text that contains many words and characters to process", + "This is a longer piece of TEXT that contains many words and characters to process", + "This is a longer piece of text that contains many words and characters to process", + ), + // Change at the beginning + ( + "ORIGINAL start of text", + "NEW start of text", + "DIFFERENT start of text", + ), + // Change at the end + ( + "text ending ORIGINAL", + "text ending NEW", + "text ending DIFFERENT", + ), + // Whitespace (should be ignored) + ("hello world", "hello there", "hello world"), + ("a b c d", "a X c d", "a Y c d"), + ]; + + for (original, expected, actual) in test_cases { + let optimized_metrics = delta_chr_f(original, expected, actual); + let reference_metrics = delta_chr_f_reference(original, expected, actual); + + assert!( + (optimized_metrics.score - reference_metrics.score).abs() < 1e-10, + "Score mismatch for ({:?}, {:?}, {:?}):\n optimized: {}\n reference: {}", + original, + expected, + actual, + optimized_metrics.score, + reference_metrics.score + ); + assert_eq!( + optimized_metrics.counts.true_positives, + reference_metrics.counts.true_positives + ); + assert_eq!( + optimized_metrics.counts.false_positives, + reference_metrics.counts.false_positives + ); + assert_eq!( + optimized_metrics.counts.false_negatives, + reference_metrics.counts.false_negatives + ); + assert!((optimized_metrics.precision - reference_metrics.precision).abs() < 1e-10); + assert!((optimized_metrics.recall - reference_metrics.recall).abs() < 1e-10); + } + } + + #[test] + fn test_delta_chr_f_metrics_include_counts_and_rates() { + let original = "one two three"; + let expected = "one three"; + let actual = "one two four"; + + let metrics = delta_chr_f(original, expected, actual); + + assert!(metrics.score > 20.0 && metrics.score < 40.0); + assert!(metrics.counts.true_positives > 0); + assert!(metrics.counts.false_positives > 0); + assert!(metrics.counts.false_negatives > 0); + assert!(metrics.precision > 
0.0 && metrics.precision < 1.0); + assert!(metrics.recall > 0.0 && metrics.recall < 1.0); + assert_eq!(metrics.beta, CHR_F_BETA); + } +} + +#[cfg(test)] +mod test { + use super::*; + use indoc::indoc; + + fn cursor_on_line(one_based_line: u32) -> u32 { + one_based_line - 1 + } + + #[test] + fn test_delta_chr_f_perfect_match() { + let original = "fn main() { println!(\"Hello\");}"; + let expected = "fn main() { println!(\"Hello, World!\");}"; + + let score = delta_chr_f(original, expected, expected).score; + assert!((score - 100.0).abs() < 1e-2); + } + + #[test] + fn test_delta_chr_f_wrong_edit() { + // When the edit is wrong + let original = "one two three"; + let expected = "one three"; // deleted "two " + let actual = "one two four"; // deleted "three", added "four" + + // Then the score should be low + let score = delta_chr_f(original, expected, actual).score; + assert!(score > 20.0 && score < 40.0); + } + + #[test] + fn test_delta_chr_f_partial_match() { + let original = "let x = 42;"; + let expected = "let x = 100;"; + let actual = "let x = 99;"; + + // We got the edit location right, but the replacement text is wrong. + // Deleted ngrams will match, bringing the score somewhere in the middle. 
+ let score = delta_chr_f(original, expected, actual).score; + assert!(score > 40.0 && score < 60.0); + } + + #[test] + fn test_delta_chr_f_missed_edit() { + // When predictions makes no changes + let original = "prefix old suffix"; + let expected = "prefix new suffix"; + let actual = "prefix old suffix"; // no change + + // Then the score should be low (all expected changes are false negatives) + let score = delta_chr_f(original, expected, actual).score; + assert!(score < 20.0); + } + + #[test] + fn test_delta_chr_f_extra_edit() { + // When adding unexpected content + let original = "helloworld"; + let expected = "helloworld"; // no change expected + let actual = "helloextraworld"; // added "extra" + + // Then the score should be low (all actual changes are false positives) + let score = delta_chr_f(original, expected, actual).score; + assert!(score < 20.0); + } + + #[test] + fn test_delta_chr_f_no_changes() { + let text = "unchanged text"; + let score = delta_chr_f(text, text, text).score; + assert!((score - 100.0).abs() < 1e-2); + } + + #[test] + fn test_braces_disbalance() { + let text = "let x = { 1 + 2 };"; + assert_eq!(braces_disbalance(text), 0); + + let text = "let x = { 1 + 2"; + assert_eq!(braces_disbalance(text), 1); + + let text = "let x = { 1 + 2 )"; + assert_eq!(braces_disbalance(text), 2); + } + + #[test] + fn test_extract_changed_lines_from_diff() { + let diff = r#"--- a/file.rs ++++ b/file.rs +@@ -1,3 +1,3 @@ + fn main() { +- println!("hello"); ++ println!("world"); + }"#; + + let counts = extract_changed_lines_from_diff(diff); + assert_eq!(counts.get("- println!(\"hello\");"), Some(&1)); + assert_eq!(counts.get("+ println!(\"world\");"), Some(&1)); + assert_eq!(counts.len(), 2); + } + + #[test] + fn test_extract_changed_lines_skips_headers() { + let diff = r#"diff --git a/file.rs b/file.rs +index abc123..def456 100644 +--- a/file.rs ++++ b/file.rs +@@ -1,2 +1,2 @@ +-old line ++new line"#; + + let counts = extract_changed_lines_from_diff(diff); + 
assert_eq!(counts.get("-old line"), Some(&1)); + assert_eq!(counts.get("+new line"), Some(&1)); + assert_eq!(counts.len(), 2); + } + + #[test] + fn test_exact_lines_match_perfect() { + let expected = r#"--- a/file.rs ++++ b/file.rs +@@ -1,3 +1,3 @@ +-old line 1 +-old line 2 ++new line 1 ++new line 2"#; + + let actual = r#"--- a/file.rs ++++ b/file.rs +@@ -1,3 +1,3 @@ +-old line 1 +-old line 2 ++new line 1 ++new line 2"#; + + let metrics = exact_lines_match(expected, actual); + assert_eq!(metrics.true_positives, 4); + assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 0); + assert!((metrics.precision() - 1.0).abs() < 1e-6); + assert!((metrics.recall() - 1.0).abs() < 1e-6); + assert!((metrics.f1() - 1.0).abs() < 1e-6); + } + + #[test] + fn test_exact_lines_match_partial() { + let expected = r#"-old line 1 +-old line 2 ++new line 1 ++new line 2"#; + + let actual = r#"-old line 1 ++new line 1 ++extra line"#; + + let metrics = exact_lines_match(expected, actual); + // TP: "-old line 1" and "+new line 1" (2) + // FP: "+extra line" (1) + // FN: "-old line 2" and "+new line 2" (2) + assert_eq!(metrics.true_positives, 2); + assert_eq!(metrics.false_positives, 1); + assert_eq!(metrics.false_negatives, 2); + } + + #[test] + fn test_exact_lines_match_no_overlap() { + let expected = r#"-line a ++line b"#; + + let actual = r#"-line x ++line y"#; + + let metrics = exact_lines_match(expected, actual); + assert_eq!(metrics.true_positives, 0); + assert_eq!(metrics.false_positives, 2); + assert_eq!(metrics.false_negatives, 2); + assert!((metrics.precision()).abs() < 1e-6); + assert!((metrics.recall()).abs() < 1e-6); + } + + #[test] + fn test_exact_lines_match_duplicate_lines() { + let expected = r#"+line a ++line a ++line a"#; + + let actual = r#"+line a ++line a"#; + + let metrics = exact_lines_match(expected, actual); + // Expected has 3 "+line a", actual has 2 + // TP: 2, FN: 1, FP: 0 + assert_eq!(metrics.true_positives, 2); + 
assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 1); + } + + #[test] + fn test_exact_lines_match_empty_patches() { + let metrics = exact_lines_match("", ""); + assert_eq!(metrics.true_positives, 0); + assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 0); + } + + #[test] + fn test_is_editable_region_correct() { + let patch = indoc! {" + @@ -1,1 +1,1 @@ + -context + -removed + -from the beginning of the file + import sys + +sys.exit(0) + + "}; + assert!(!is_editable_region_correct(patch)); + + let patch = indoc! {" + @@ -1,1 +1,1 @@ + "}; + assert!(is_editable_region_correct(patch)); + } + + #[test] + fn test_isolated_whitespace_purely_whitespace_patch() { + let patch = indoc! {" + @@ -1,3 +1,4 @@ + fn main() { + + + println!(\"hello\"); + } + "}; + assert!(has_isolated_whitespace_changes(patch, None)); + } + + #[test] + fn test_isolated_whitespace_adjacent_to_real_change() { + let patch = indoc! {" + @@ -1,3 +1,4 @@ + fn main() { + + + + let x = 1; + println!(\"hello\"); + } + "}; + assert!(!has_isolated_whitespace_changes(patch, None)); + } + + #[test] + fn test_isolated_whitespace_no_whitespace_changes() { + let patch = indoc! {" + @@ -1,3 +1,3 @@ + fn main() { + - println!(\"hello\"); + + println!(\"world\"); + } + "}; + assert!(!has_isolated_whitespace_changes(patch, None)); + } + + #[test] + fn test_isolated_whitespace_deletion() { + let patch = indoc! {" + @@ -1,4 +1,3 @@ + fn main() { + - + println!(\"hello\"); + } + "}; + assert!(has_isolated_whitespace_changes(patch, None)); + } + + #[test] + fn test_isolated_whitespace_mixed_groups() { + let patch = indoc! 
{" + @@ -1,7 +1,8 @@ + fn main() { + + + let x = 1; + - let y = 2; + + let y = 3; + + + + println!(\"hello\"); + } + "}; + assert!(has_isolated_whitespace_changes(patch, None)); + } + + #[test] + fn test_isolated_whitespace_empty_patch() { + let patch = ""; + assert!(!has_isolated_whitespace_changes(patch, None)); + } + + #[test] + fn test_isolated_whitespace_skipped_on_cursor_line() { + // The addition of a blank line at new-file line 2 should be skipped + // because the cursor is on that line. + let patch = indoc! {" + @@ -1,3 +1,4 @@ + fn main() { + + + println!(\"hello\"); + } + "}; + // New-file line 2 is the added blank line + let cursor = cursor_on_line(2); + assert!(!has_isolated_whitespace_changes(patch, Some(cursor))); + } + + #[test] + fn test_isolated_whitespace_not_skipped_when_cursor_on_different_line() { + // The blank line is at new-file line 2, but the cursor is on line 1. + let patch = indoc! {" + @@ -1,3 +1,4 @@ + fn main() { + + + println!(\"hello\"); + } + "}; + let cursor = cursor_on_line(1); + assert!(has_isolated_whitespace_changes(patch, Some(cursor))); + } + + #[test] + fn test_isolated_whitespace_deletion_not_skipped_by_cursor() { + // Deletions don't have a new-file line, so cursor can't suppress them. + let patch = indoc! 
{" + @@ -1,4 +1,3 @@ + fn main() { + - + println!(\"hello\"); + } + "}; + let cursor = cursor_on_line(2); + assert!(has_isolated_whitespace_changes(patch, Some(cursor))); + } + + #[test] + fn test_count_patch_token_changes_real_world_rename() { + // Real-world patch that was reported as returning 0 tokens + let patch = "--- a/sip_call\\README.md\n+++ b/sip_call\\README.md\n@@ -1,1 +1,1 @@\n-# \n+# SIP Call\n"; + let counts = count_patch_token_changes(patch); + // "# " vs "# SIP Call" — the "SIP" and "Call" tokens (and a whitespace token) are inserted + assert!( + counts.inserted_tokens > 0, + "expected inserted tokens > 0, got {}", + counts.inserted_tokens + ); + assert_eq!(counts.deleted_tokens, 0); + } + + #[test] + fn test_count_patch_token_changes_real_world_expansion() { + // Real-world patch: single token expanded to multiple lines + let patch = "--- a/task1/src/app/app.html\n+++ b/task1/src/app/app.html\n@@ -1,7 +1,9 @@\n \n \n
\n \n
\n"; + let counts = count_patch_token_changes(patch); + assert!( + counts.inserted_tokens > 0, + "expected inserted tokens > 0, got {}", + counts.inserted_tokens + ); + assert!( + counts.deleted_tokens > 0, + "expected deleted tokens > 0, got {}", + counts.deleted_tokens + ); + } + + #[test] + fn test_count_patch_token_changes_simple_replacement() { + let patch = indoc! {" + @@ -1,3 +1,3 @@ + fn main() { + - println!(\"hello\"); + + println!(\"world\"); + } + "}; + let counts = count_patch_token_changes(patch); + assert_eq!(counts.deleted_tokens, 1, "deleted: \"hello\""); + assert_eq!(counts.inserted_tokens, 1, "inserted: \"world\""); + } + + #[test] + fn test_count_patch_token_changes_insertion_only() { + let patch = indoc! {" + @@ -1,2 +1,3 @@ + fn main() { + + println!(\"hello\"); + } + "}; + let counts = count_patch_token_changes(patch); + assert_eq!(counts.deleted_tokens, 0); + assert!(counts.inserted_tokens > 0); + } + + #[test] + fn test_count_patch_token_changes_deletion_only() { + let patch = indoc! {" + @@ -1,3 +1,2 @@ + fn main() { + - println!(\"hello\"); + } + "}; + let counts = count_patch_token_changes(patch); + assert!(counts.deleted_tokens > 0); + assert_eq!(counts.inserted_tokens, 0); + } + + #[test] + fn test_count_patch_token_changes_empty_patch() { + let patch = ""; + let counts = count_patch_token_changes(patch); + assert_eq!(counts.deleted_tokens, 0); + assert_eq!(counts.inserted_tokens, 0); + } + + #[test] + fn test_count_patch_token_changes_multiple_hunks() { + let patch = indoc! {" + @@ -1,3 +1,3 @@ + fn main() { + - let x = 1; + + let x = 2; + } + @@ -10,3 +10,3 @@ + fn other() { + - let y = 3; + + let y = 4; + } + "}; + let counts = count_patch_token_changes(patch); + assert_eq!(counts.deleted_tokens, 2, "deleted: \"1\" and \"3\""); + assert_eq!(counts.inserted_tokens, 2, "inserted: \"2\" and \"4\""); + } + + #[test] + fn test_count_patch_token_changes_multiword_change() { + let patch = indoc! 
{" + @@ -1,1 +1,1 @@ + -hello world foo + +hello bar baz + "}; + let counts = count_patch_token_changes(patch); + // "world" and "foo" deleted, "bar" and "baz" inserted + // (whitespace tokens between them may also count) + assert!(counts.deleted_tokens >= 2); + assert!(counts.inserted_tokens >= 2); + } + + #[test] + fn test_whitespace_collapse() { + let text = "abc \n\n\n 123"; + let collapsed = collapse_whitespace(text.chars()); + assert_eq!( + collapsed, + vec!['a', 'b', 'c', ' ', '\n', ' ', '1', '2', '3'] + ); + } +} diff --git a/crates/edit_prediction_metrics/src/reversal.rs b/crates/edit_prediction_metrics/src/reversal.rs new file mode 100644 index 0000000000000000000000000000000000000000..a1fff663d554f398d900a1b553f4321034a9661f --- /dev/null +++ b/crates/edit_prediction_metrics/src/reversal.rs @@ -0,0 +1,648 @@ +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; + +use language::{char_diff, text_diff}; +use zeta_prompt::udiff::apply_diff_to_string; + +fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String { + let hunks = parse_diff_hunks(diff_str); + let mut result = text.to_string(); + + for hunk in hunks { + let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk)); + if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) { + result = updated; + } + } + + result +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ParsedHunk { + old_start: u32, + old_count: u32, + new_start: u32, + new_count: u32, + lines: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum HunkLine { + Context(String), + Addition(String), + Deletion(String), +} + +fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> { + let line = line.strip_prefix("@@ -")?; + let (old_part, rest) = line.split_once(' ')?; + let rest = rest.strip_prefix('+')?; + let (new_part, _) = rest.split_once(" @@")?; + + let (old_start, old_count) = if let Some((start, count)) = old_part.split_once(',') { + (start.parse().ok()?, count.parse().ok()?) 
+ } else { + (old_part.parse().ok()?, 1) + }; + + let (new_start, new_count) = if let Some((start, count)) = new_part.split_once(',') { + (start.parse().ok()?, count.parse().ok()?) + } else { + (new_part.parse().ok()?, 1) + }; + + Some((old_start, old_count, new_start, new_count)) +} + +fn parse_diff_hunks(diff: &str) -> Vec { + let mut hunks = Vec::new(); + let mut current_hunk: Option = None; + + for line in diff.lines() { + if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) { + if let Some(hunk) = current_hunk.take() { + hunks.push(hunk); + } + current_hunk = Some(ParsedHunk { + old_start, + old_count, + new_start, + new_count, + lines: Vec::new(), + }); + } else if let Some(ref mut hunk) = current_hunk { + if let Some(stripped) = line.strip_prefix('+') { + hunk.lines.push(HunkLine::Addition(stripped.to_string())); + } else if let Some(stripped) = line.strip_prefix('-') { + hunk.lines.push(HunkLine::Deletion(stripped.to_string())); + } else if let Some(stripped) = line.strip_prefix(' ') { + hunk.lines.push(HunkLine::Context(stripped.to_string())); + } else if line.is_empty() { + hunk.lines.push(HunkLine::Context(String::new())); + } + } + } + + if let Some(hunk) = current_hunk { + hunks.push(hunk); + } + + hunks +} + +fn format_hunk(hunk: &ParsedHunk) -> String { + let mut result = format!( + "@@ -{},{} +{},{} @@\n", + hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count + ); + for line in &hunk.lines { + match line { + HunkLine::Context(text) => { + result.push(' '); + result.push_str(text); + result.push('\n'); + } + HunkLine::Addition(text) => { + result.push('+'); + result.push_str(text); + result.push('\n'); + } + HunkLine::Deletion(text) => { + result.push('-'); + result.push_str(text); + result.push('\n'); + } + } + } + result +} + +fn filter_diff_hunks_by_excerpt( + diff: &str, + excerpt_start_row: u32, + excerpt_row_count: u32, +) -> (String, i32) { + let hunks = parse_diff_hunks(diff); + let 
excerpt_start_0based = excerpt_start_row; + let excerpt_end_0based = excerpt_start_row + excerpt_row_count; + + let mut filtered_hunks = Vec::new(); + let mut cumulative_line_offset: i32 = 0; + + for hunk in hunks { + let hunk_start_0based = hunk.new_start.saturating_sub(1); + let hunk_end_0based = hunk_start_0based + hunk.new_count; + + let additions: i32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Addition(_))) + .count() as i32; + let deletions: i32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Deletion(_))) + .count() as i32; + let hunk_line_delta = additions - deletions; + + if hunk_end_0based <= excerpt_start_0based { + cumulative_line_offset += hunk_line_delta; + continue; + } + + if hunk_start_0based >= excerpt_end_0based { + continue; + } + + let mut filtered_lines = Vec::new(); + let mut current_row_0based = hunk_start_0based; + let mut filtered_old_count = 0u32; + let mut filtered_new_count = 0u32; + let mut first_included_row: Option = None; + + for line in &hunk.lines { + match line { + HunkLine::Context(text) => { + if current_row_0based >= excerpt_start_0based + && current_row_0based < excerpt_end_0based + { + if first_included_row.is_none() { + first_included_row = Some(current_row_0based); + } + filtered_lines.push(HunkLine::Context(text.clone())); + filtered_old_count += 1; + filtered_new_count += 1; + } + current_row_0based += 1; + } + HunkLine::Addition(text) => { + if current_row_0based >= excerpt_start_0based + && current_row_0based < excerpt_end_0based + { + if first_included_row.is_none() { + first_included_row = Some(current_row_0based); + } + filtered_lines.push(HunkLine::Addition(text.clone())); + filtered_new_count += 1; + } + current_row_0based += 1; + } + HunkLine::Deletion(text) => { + if current_row_0based >= excerpt_start_0based + && current_row_0based < excerpt_end_0based + { + if first_included_row.is_none() { + first_included_row = Some(current_row_0based); + } + 
filtered_lines.push(HunkLine::Deletion(text.clone())); + filtered_old_count += 1; + } + } + } + } + + if !filtered_lines.is_empty() { + let first_row = first_included_row.unwrap_or(excerpt_start_0based); + let new_start_1based = (first_row - excerpt_start_0based) + 1; + + filtered_hunks.push(ParsedHunk { + old_start: new_start_1based, + old_count: filtered_old_count, + new_start: new_start_1based, + new_count: filtered_new_count, + lines: filtered_lines, + }); + } + + cumulative_line_offset += hunk_line_delta; + } + + let mut result = String::new(); + for hunk in &filtered_hunks { + result.push_str(&format_hunk(hunk)); + } + + (result, cumulative_line_offset) +} + +fn compute_excerpt_aware_reversal_overlap( + edit_history_diffs: &[&str], + excerpt_content: &str, + excerpt_start_row: u32, + predicted_content: &str, +) -> ReversalOverlap { + let mut current_content = excerpt_content.to_string(); + let mut current_excerpt_start_row = excerpt_start_row; + + for diff in edit_history_diffs.iter().rev() { + if diff.is_empty() { + continue; + } + + let current_row_count = current_content.lines().count() as u32; + let (filtered_diff, _line_offset) = + filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1)); + + if filtered_diff.is_empty() { + let hunks = parse_diff_hunks(diff); + for hunk in hunks { + let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count; + if hunk_end <= current_excerpt_start_row { + let additions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Addition(_))) + .count() as u32; + let deletions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Deletion(_))) + .count() as u32; + if additions >= deletions { + current_excerpt_start_row = + current_excerpt_start_row.saturating_sub(additions - deletions); + } else { + current_excerpt_start_row += deletions - additions; + } + } + } + continue; + } + + let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff)); + 
match apply_diff_to_string(&reversed, ¤t_content) { + Ok(updated) => { + current_content = updated; + } + Err(_) => { + continue; + } + } + + let hunks = parse_diff_hunks(diff); + for hunk in hunks { + let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count; + if hunk_end <= current_excerpt_start_row { + let additions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Addition(_))) + .count() as u32; + let deletions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Deletion(_))) + .count() as u32; + if additions >= deletions { + current_excerpt_start_row = + current_excerpt_start_row.saturating_sub(additions - deletions); + } else { + current_excerpt_start_row += deletions - additions; + } + } + } + } + + compute_reversal_overlap(¤t_content, excerpt_content, predicted_content) +} + +fn reverse_diff(diff: &str) -> String { + let mut result: String = diff + .lines() + .map(|line| { + if line.starts_with("--- ") { + line.replacen("--- ", "+++ ", 1) + } else if line.starts_with("+++ ") { + line.replacen("+++ ", "--- ", 1) + } else if line.starts_with('+') && !line.starts_with("+++") { + format!("-{}", &line[1..]) + } else if line.starts_with('-') && !line.starts_with("---") { + format!("+{}", &line[1..]) + } else { + line.to_string() + } + }) + .collect::>() + .join("\n"); + if diff.ends_with('\n') { + result.push('\n'); + } + result +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct GranularEdit { + range: Range, + old_text: String, + new_text: String, +} + +fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec { + text_diff(old_text, new_text) + .into_iter() + .map(|(range, new_text)| GranularEdit { + old_text: old_text[range.clone()].to_string(), + range, + new_text: new_text.to_string(), + }) + .collect() +} + +#[derive(Debug, Clone)] +struct HistoryAdditionRange { + range_in_current: Range, +} + +#[derive(Debug, Clone)] +struct HistoryDeletionRange { + deleted_text: String, + position_in_current: usize, +} + 
+fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec { + let mut result = Vec::new(); + let mut offset_delta: isize = 0; + + for edit in history_edits { + if !edit.new_text.is_empty() { + let new_start = (edit.range.start as isize + offset_delta) as usize; + let new_end = new_start + edit.new_text.len(); + result.push(HistoryAdditionRange { + range_in_current: new_start..new_end, + }); + } + + offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize; + } + + result +} + +fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec { + let mut result = Vec::new(); + let mut offset_delta: isize = 0; + + for edit in history_edits { + if !edit.old_text.is_empty() { + let position_in_current = (edit.range.start as isize + offset_delta) as usize; + result.push(HistoryDeletionRange { + deleted_text: edit.old_text.clone(), + position_in_current, + }); + } + + offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize; + } + + result +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +struct ReversalOverlap { + chars_reversing_user_edits: usize, + total_chars_in_prediction: usize, +} + +impl ReversalOverlap { + fn ratio(&self) -> f32 { + if self.total_chars_in_prediction == 0 { + 0.0 + } else { + self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32 + } + } +} + +/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension), +/// or where `new_text` appears as a subsequence within `old_text` (reduction). +/// +/// For extensions: when the user's text is preserved (in order) within the prediction, +/// we only count the newly inserted characters, not the preserved ones. 
+/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()") +/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars +/// +/// For reductions: when the prediction's text is preserved (in order) within the original, +/// we only count the deleted characters, not the preserved ones. +/// E.g., "ifrom" → "from" becomes 1 deleted char ("i") +fn normalize_extension_edits(edits: Vec) -> Vec { + edits + .into_iter() + .flat_map(|edit| { + if edit.old_text.is_empty() || edit.new_text.is_empty() { + return vec![edit]; + } + + // Use character-wise diff to find exact byte ranges of changes + let char_edits = char_diff(&edit.old_text, &edit.new_text); + + let all_deletions = !char_edits.is_empty() + && char_edits + .iter() + .all(|(range, replacement)| !range.is_empty() && replacement.is_empty()); + let all_insertions = !char_edits.is_empty() + && char_edits + .iter() + .all(|(range, replacement)| range.is_empty() && !replacement.is_empty()); + if all_deletions || all_insertions { + return char_edits + .into_iter() + .map(|(range, replacement)| GranularEdit { + range: edit.range.start + range.start..edit.range.start + range.end, + old_text: edit.old_text[range].to_string(), + new_text: replacement.to_string(), + }) + .collect(); + } + + // Otherwise, keep the original edit (mixed changes) + vec![edit] + }) + .collect() +} + +fn compute_reversal_overlap( + original_content: &str, + current_content: &str, + predicted_content: &str, +) -> ReversalOverlap { + let history_edits = + normalize_extension_edits(compute_granular_edits(original_content, current_content)); + let prediction_edits = + normalize_extension_edits(compute_granular_edits(current_content, predicted_content)); + + let history_addition_ranges = compute_history_addition_ranges(&history_edits); + let history_deletion_ranges = compute_history_deletion_ranges(&history_edits); + + let reversed_additions = + compute_reversed_additions(&history_addition_ranges, 
&prediction_edits); + let restored_deletions = + compute_restored_deletions(&history_deletion_ranges, &prediction_edits); + + let total_chars_in_prediction: usize = prediction_edits + .iter() + .map(|e| e.new_text.chars().count() + e.old_text.chars().count()) + .sum(); + + ReversalOverlap { + chars_reversing_user_edits: reversed_additions + restored_deletions, + total_chars_in_prediction, + } +} + +fn compute_reversed_additions( + history_addition_ranges: &[HistoryAdditionRange], + prediction_edits: &[GranularEdit], +) -> usize { + let mut reversed_chars = 0; + + for pred_edit in prediction_edits { + for history_addition in history_addition_ranges { + let overlap_start = pred_edit + .range + .start + .max(history_addition.range_in_current.start); + let overlap_end = pred_edit + .range + .end + .min(history_addition.range_in_current.end); + + if overlap_start < overlap_end { + let relative_start = overlap_start - pred_edit.range.start; + let relative_end = overlap_end - pred_edit.range.start; + let overlap_text = &pred_edit.old_text[relative_start..relative_end]; + reversed_chars += overlap_text.chars().count(); + } + } + } + + reversed_chars +} + +fn compute_restored_deletions( + history_deletion_ranges: &[HistoryDeletionRange], + prediction_edits: &[GranularEdit], +) -> usize { + let mut restored = 0; + + for pred_edit in prediction_edits { + if pred_edit.new_text.is_empty() { + continue; + } + + for deletion in history_deletion_ranges { + if pred_edit.range.contains(&deletion.position_in_current) + || deletion.position_in_current == pred_edit.range.start + { + restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text); + } + } + } + + restored +} + +fn compute_lcs_length(a: &str, b: &str) -> usize { + let a_chars: Vec = a.chars().collect(); + let b_chars: Vec = b.chars().collect(); + let m = a_chars.len(); + let n = b_chars.len(); + + if m == 0 || n == 0 { + return 0; + } + + let mut prev = vec![0; n + 1]; + let mut curr = vec![0; n + 1]; + + for 
i in 1..=m { + for j in 1..=n { + if a_chars[i - 1] == b_chars[j - 1] { + curr[j] = prev[j - 1] + 1; + } else { + curr[j] = prev[j].max(curr[j - 1]); + } + } + std::mem::swap(&mut prev, &mut curr); + curr.fill(0); + } + + prev[n] +} + +fn filter_edit_history_by_path<'a>( + edit_history: &'a [Arc], + cursor_path: &std::path::Path, +) -> Vec<&'a zeta_prompt::Event> { + edit_history + .iter() + .filter(|event| match event.as_ref() { + zeta_prompt::Event::BufferChange { path, .. } => { + let event_path = path.as_ref(); + if event_path == cursor_path { + return true; + } + let stripped = event_path + .components() + .skip(1) + .collect::(); + stripped == cursor_path + } + }) + .map(|arc| arc.as_ref()) + .collect() +} + +fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str { + match event { + zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(), + } +} + +fn is_predicted_event(event: &zeta_prompt::Event) -> bool { + match event { + zeta_prompt::Event::BufferChange { predicted, .. 
} => *predicted, + } +} + +pub fn compute_prediction_reversal_ratio_from_history( + current_content: &str, + edit_history: &[Arc], + excerpt_start_row: Option, + predicted_content: &str, + cursor_path: &Path, +) -> f32 { + let relevant_events = filter_edit_history_by_path(edit_history, cursor_path); + + let most_recent = match relevant_events.last() { + Some(event) if !is_predicted_event(event) => *event, + _ => return 0.0, + }; + + let diff = extract_diff_from_event(most_recent); + if diff.is_empty() { + return 0.0; + } + + if let Some(excerpt_start_row) = excerpt_start_row { + let diffs = vec![diff]; + let overlap = compute_excerpt_aware_reversal_overlap( + &diffs, + current_content, + excerpt_start_row, + predicted_content, + ); + return overlap.ratio(); + } + + let reversed = reverse_diff(diff); + let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed); + let original_content = match apply_diff_to_string(&with_headers, current_content) { + Ok(updated_content) => updated_content, + Err(_) => apply_diff_to_string_lenient(&reversed, current_content), + }; + + let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content); + overlap.ratio() +} diff --git a/crates/edit_prediction/src/metrics/tokenize.rs b/crates/edit_prediction_metrics/src/tokenize.rs similarity index 100% rename from crates/edit_prediction/src/metrics/tokenize.rs rename to crates/edit_prediction_metrics/src/tokenize.rs diff --git a/crates/edit_prediction_metrics/src/tree_sitter.rs b/crates/edit_prediction_metrics/src/tree_sitter.rs new file mode 100644 index 0000000000000000000000000000000000000000..96f3135df4cdb7a941161dd603924c72787b31d8 --- /dev/null +++ b/crates/edit_prediction_metrics/src/tree_sitter.rs @@ -0,0 +1,27 @@ +pub fn count_tree_sitter_errors<'a>(nodes: impl Iterator>) -> usize { + let mut total_count: usize = 0; + for node in nodes { + let mut cursor = node.walk(); + 'node: loop { + let current = cursor.node(); + if current.is_error() || 
current.is_missing() { + total_count += 1; + } + if current.has_error() && cursor.goto_first_child() { + continue; + } + if cursor.goto_next_sibling() { + continue; + } + loop { + if !cursor.goto_parent() { + break 'node; + } + if cursor.goto_next_sibling() { + continue; + } + } + } + } + total_count +}