diff --git a/Cargo.lock b/Cargo.lock
index fdd1852e07db1dd62421f076dc4f5a525fca8748..b16363d091696d26016338cd62bbcb5ec8f5a447 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5338,6 +5338,7 @@ dependencies = [
  "language",
  "pretty_assertions",
  "serde",
+ "serde_json",
  "similar",
  "tree-sitter",
  "zeta_prompt",
diff --git a/crates/edit_prediction_metrics/Cargo.toml b/crates/edit_prediction_metrics/Cargo.toml
index 02181ca3d2456df102e5f4f93e7e3b0e0a5d8313..62184f2f6f93da035c3b2196620dd568653df746 100644
--- a/crates/edit_prediction_metrics/Cargo.toml
+++ b/crates/edit_prediction_metrics/Cargo.toml
@@ -14,6 +14,7 @@ path = "src/edit_prediction_metrics.rs"
 [dependencies]
 language.workspace = true
 serde.workspace = true
+serde_json = "1.0"
 similar = "2.7.0"
 tree-sitter.workspace = true
 zeta_prompt.workspace = true
diff --git a/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs b/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs
index 4fbaaf71331c285009091c9bd7b16eafdc6d2829..3afe02fd083076d84eba0c1cd359a272b08525c0 100644
--- a/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs
+++ b/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs
@@ -4,9 +4,10 @@ mod reversal;
 mod tokenize;
 mod tree_sitter;
 
+pub use kept_rate::AnnotatedToken;
 pub use kept_rate::KeptRateResult;
-#[cfg(test)]
 pub use kept_rate::TokenAnnotation;
+pub use kept_rate::annotate_kept_rate_tokens;
 pub use kept_rate::compute_kept_rate;
 pub use patch_metrics::ClassificationMetrics;
 pub use patch_metrics::Counts;
@@ -20,5 +21,6 @@ pub use patch_metrics::exact_lines_match;
 pub use patch_metrics::extract_changed_lines_from_diff;
 pub use patch_metrics::has_isolated_whitespace_changes;
 pub use patch_metrics::is_editable_region_correct;
+pub use patch_metrics::reconstruct_texts_from_diff;
 pub use reversal::compute_prediction_reversal_ratio_from_history;
 pub use tree_sitter::count_tree_sitter_errors;
diff --git a/crates/edit_prediction_metrics/src/kept_rate.rs b/crates/edit_prediction_metrics/src/kept_rate.rs
index 117ab743c2b0ef51e0f31bc97d1b38af3b534a47..262c0f85c14e7fe228e39546cf0f9d1b686d027b 100644
--- a/crates/edit_prediction_metrics/src/kept_rate.rs
+++ b/crates/edit_prediction_metrics/src/kept_rate.rs
@@ -3,14 +3,20 @@ use serde::Serialize;
 
 const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
 
-#[cfg(test)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
+#[serde(rename_all = "snake_case")]
 pub enum TokenAnnotation {
     Context,
     Kept,
     Discarded,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
+pub struct AnnotatedToken {
+    pub token: String,
+    pub annotation: TokenAnnotation,
+}
+
 #[allow(dead_code)]
 #[derive(Debug, Clone, Serialize)]
 pub struct KeptRateResult {
@@ -40,8 +46,7 @@ pub struct KeptRateResult {
     /// This includes both kept newly introduced characters and correctly
     /// deleted base characters.
     pub recall_rate: f64,
-    /// Per-token classification for candidate tokens used by tests.
-    #[cfg(test)]
+    /// Per-token classification for candidate tokens.
     pub token_annotations: Vec<TokenAnnotation>,
 }
 
@@ -51,9 +56,9 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
 
 /// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
 /// while sharing a single DP table construction.
-fn fill_lcs_keep_masks(
-    a: &[&str],
-    b: &[&str],
+fn fill_lcs_keep_masks<T: PartialEq>(
+    a: &[T],
+    b: &[T],
     mut keep_a: Option<&mut [bool]>,
     mut keep_b: Option<&mut [bool]>,
 ) {
@@ -124,10 +129,10 @@ fn fill_lcs_keep_masks(
     let mut dp = vec![0u32; row_count * column_count];
 
     for i in 1..row_count {
-        let token_a = a_mid[i - 1];
+        let token_a = &a_mid[i - 1];
         for j in 1..column_count {
             let index = dp_index(column_count, i, j);
-            if token_a == b_mid[j - 1] {
+            if token_a == &b_mid[j - 1] {
                 dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
             } else {
                 let up = dp[dp_index(column_count, i - 1, j)];
@@ -180,41 +185,91 @@ fn fill_lcs_keep_masks(
     }
 }
 
-fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
+fn lcs_keep_mask<T: PartialEq>(a: &[T], b: &[T]) -> Vec<bool> {
     let mut keep_a = vec![false; a.len()];
     fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
     keep_a
 }
 
-fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
+fn lcs_keep_masks<T: PartialEq>(a: &[T], b: &[T]) -> (Vec<bool>, Vec<bool>) {
     let mut keep_a = vec![false; a.len()];
     let mut keep_b = vec![false; b.len()];
     fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
     (keep_a, keep_b)
 }
 
-fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
-    let mut unmasked_tokens = Vec::with_capacity(tokens.len());
+#[derive(Debug, Clone)]
+struct ComparisonUnit {
+    text: String,
+    token_start: usize,
+    token_end: usize,
+}
+
+fn is_identifier_token(token: &str) -> bool {
+    !token.is_empty()
+        && token
+            .chars()
+            .all(|character| character.is_alphanumeric() || character == '_')
+}
+
+fn build_comparison_units(tokens: &[&str]) -> Vec<ComparisonUnit> {
+    let mut units = Vec::new();
+    let mut index = 0;
+
+    while index < tokens.len() {
+        let token_start = index;
+
+        if is_identifier_token(tokens[index]) {
+            let mut text = String::new();
+
+            while index < tokens.len() && is_identifier_token(tokens[index]) {
+                text.push_str(tokens[index]);
+                index += 1;
+            }
+
+            units.push(ComparisonUnit {
+                text,
+                token_start,
+                token_end: index,
+            });
+        } else {
+            units.push(ComparisonUnit {
+                text: tokens[index].to_string(),
+                token_start,
+                token_end: index + 1,
+            });
+            index += 1;
+        }
+    }
+
+    units
+}
+
+fn analyze_masked_units<'a>(
+    units: &'a [ComparisonUnit],
+    mask: &[bool],
+) -> (Vec<&'a str>, usize, usize) {
+    let mut unmasked_units = Vec::with_capacity(units.len());
     let mut unmasked_chars = 0;
     let mut masked_chars = 0;
 
-    for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
+    for (unit, &is_masked) in units.iter().zip(mask.iter()) {
         if is_masked {
-            masked_chars += token.len();
+            masked_chars += unit.text.len();
         } else {
-            unmasked_tokens.push(token);
-            unmasked_chars += token.len();
+            unmasked_units.push(unit.text.as_str());
+            unmasked_chars += unit.text.len();
         }
     }
 
-    (unmasked_tokens, unmasked_chars, masked_chars)
+    (unmasked_units, unmasked_chars, masked_chars)
 }
 
-fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
-    tokens
+fn count_unmasked_unit_chars(units: &[ComparisonUnit], mask: &[bool]) -> usize {
+    units
         .iter()
         .zip(mask.iter())
-        .filter_map(|(&token, &is_masked)| (!is_masked).then_some(token.len()))
+        .filter_map(|(unit, &is_masked)| (!is_masked).then_some(unit.text.len()))
        .sum()
 }
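(Aside, not part of the diff: the tokenizer changes later in this patch split `foo_bar` into `["foo", "_", "bar"]`, and `build_comparison_units` above re-joins such identifier runs so the LCS compares whole identifiers, while each unit remembers the token range it covers. A minimal would-be unit test for this module, with illustrative inputs:)

#[test]
fn groups_identifier_runs() {
    // "foo", "_", "bar" are all identifier tokens, so they collapse into one
    // comparison unit; punctuation stays a unit of its own.
    let tokens = vec!["foo", "_", "bar", "(", "x", ")"];
    let units = build_comparison_units(&tokens);
    let texts: Vec<&str> = units.iter().map(|unit| unit.text.as_str()).collect();
    assert_eq!(texts, vec!["foo_bar", "(", "x", ")"]);
    // Each unit records the token range it spans, which is how annotations are
    // mapped back to individual tokens in compute_kept_rate below.
    assert_eq!((units[0].token_start, units[0].token_end), (0, 3));
}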
@@ -239,7 +294,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
             context_chars,
             kept_rate: 1.0,
             recall_rate: 1.0,
-            #[cfg(test)]
             token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
         };
     }
@@ -258,7 +312,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
             context_chars: 0,
             kept_rate: 0.0,
             recall_rate: 0.0,
-            #[cfg(test)]
             token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
         };
     }
@@ -267,29 +320,29 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
     let candidate_tokens = tokenize(candidate);
     let reference_tokens = tokenize(reference);
 
-    let (candidate_base_mask, base_candidate_mask) =
-        lcs_keep_masks(&candidate_tokens, &base_tokens);
-    let (candidate_reference_mask, reference_candidate_mask) =
-        lcs_keep_masks(&candidate_tokens, &reference_tokens);
-    let context_mask: Vec<bool> = candidate_base_mask
+    let candidate_units = build_comparison_units(&candidate_tokens);
+    let base_units = build_comparison_units(&base_tokens);
+    let reference_units = build_comparison_units(&reference_tokens);
+
+    let candidate_unit_texts: Vec<&str> = candidate_units
         .iter()
-        .zip(candidate_reference_mask.iter())
-        .map(|(&in_base, &in_reference)| in_base && in_reference)
+        .map(|unit| unit.text.as_str())
+        .collect();
+    let base_unit_texts: Vec<&str> = base_units.iter().map(|unit| unit.text.as_str()).collect();
+    let reference_unit_texts: Vec<&str> = reference_units
+        .iter()
+        .map(|unit| unit.text.as_str())
         .collect();
 
+    let (candidate_base_mask, base_candidate_mask) =
+        lcs_keep_masks(&candidate_unit_texts, &base_unit_texts);
     let (stripped_candidate, candidate_new_chars, context_chars) =
-        analyze_masked_tokens(&candidate_tokens, &context_mask);
+        analyze_masked_units(&candidate_units, &candidate_base_mask);
 
     let (reference_base_mask, base_reference_mask) =
-        lcs_keep_masks(&reference_tokens, &base_tokens);
-    let reference_context_mask: Vec<bool> = reference_base_mask
-        .iter()
-        .zip(reference_candidate_mask.iter())
-        .map(|(&in_base, &in_candidate)| in_base && in_candidate)
-        .collect();
-
+        lcs_keep_masks(&reference_unit_texts, &base_unit_texts);
     let (stripped_reference, reference_new_chars, _) =
-        analyze_masked_tokens(&reference_tokens, &reference_context_mask);
+        analyze_masked_units(&reference_units, &reference_base_mask);
 
     let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);
@@ -299,13 +352,13 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
         .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
         .sum();
 
-    let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
-    let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
-    let correctly_deleted_chars: usize = base_tokens
+    let candidate_deleted_chars = count_unmasked_unit_chars(&base_units, &base_candidate_mask);
+    let reference_deleted_chars = count_unmasked_unit_chars(&base_units, &base_reference_mask);
+    let correctly_deleted_chars: usize = base_units
         .iter()
         .zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
-        .filter_map(|(&token, (&in_candidate, &in_reference))| {
-            (!in_candidate && !in_reference).then_some(token.len())
+        .filter_map(|(unit, (&in_candidate, &in_reference))| {
+            (!in_candidate && !in_reference).then_some(unit.text.len())
         })
         .sum();
@@ -326,24 +379,28 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
         matched_edit_chars as f64 / reference_edit_chars as f64
     };
 
-    #[cfg(test)]
     let token_annotations = {
-        let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
+        let mut token_annotations = vec![TokenAnnotation::Context; candidate_tokens.len()];
         let mut new_index = 0;
-        for (token_index, _token) in candidate_tokens.iter().enumerate() {
-            if context_mask[token_index] {
-                token_annotations.push(TokenAnnotation::Context);
+
+        for (unit_index, unit) in candidate_units.iter().enumerate() {
+            let annotation = if candidate_base_mask[unit_index] {
+                TokenAnnotation::Context
             } else {
                 let annotation = if keep_mask[new_index] {
                     TokenAnnotation::Kept
                 } else {
                     TokenAnnotation::Discarded
                 };
-                #[cfg(test)]
-                token_annotations.push(annotation);
                 new_index += 1;
+                annotation
+            };
+
+            for token_index in unit.token_start..unit.token_end {
+                token_annotations[token_index] = annotation;
             }
         }
+
        token_annotations
     };
@@ -358,14 +415,30 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
         context_chars,
         kept_rate,
         recall_rate,
-        #[cfg(test)]
         token_annotations,
     }
 }
 
+pub fn annotate_kept_rate_tokens(
+    base: &str,
+    candidate: &str,
+    reference: &str,
+) -> Vec<AnnotatedToken> {
+    let result = compute_kept_rate(base, candidate, reference);
+    tokenize(candidate)
+        .into_iter()
+        .zip(result.token_annotations)
+        .map(|(token, annotation)| AnnotatedToken {
+            token: token.to_string(),
+            annotation,
+        })
+        .collect()
+}
+
 #[cfg(test)]
 mod test_kept_rate {
     use super::*;
+    use indoc::indoc;
 
     #[test]
     fn test_lcs_keep_masks() {
@@ -439,16 +512,24 @@ mod test_kept_rate {
 
     #[test]
     fn test_missing_deletion() {
-        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
-        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\neprintln!(\"\");\n";
-        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
+        let base = indoc! {"
+            fn example() {
+                epr
+        "};
+        let candidate = indoc! {r#"
+            fn example() {
+                epr
+            eprintln!("");
+        "#};
+        let reference = indoc! {r#"
+            fn example() {
+                eprintln!("");
+        "#};
+
         let result = compute_kept_rate(base, candidate, reference);
-        assert!(
-            result.kept_rate < 0.85,
-            "expected kept_rate < 0.85, got {}",
-            result.kept_rate
-        );
-        assert!(result.discarded_chars > 0);
+        assert!((result.kept_rate - (14.0 / 15.0)).abs() < 1e-6);
+        assert_eq!(result.kept_chars, 14);
+        assert_eq!(result.discarded_chars, 1);
     }
 
@@ -472,8 +553,17 @@ mod test_kept_rate {
 
     #[test]
     fn test_bails_for_dirty_final() {
-        let base = "fn example() {\n    work();\n}\n";
-        let candidate = "fn example() {\n    work();\n    predicted();\n}\n";
+        let base = indoc! {"
+            fn example() {
+                work();
+            }
+        "};
+        let candidate = indoc! {"
+            fn example() {
+                work();
+                predicted();
+            }
+        "};
         let reference = format!(
             "fn example() {{\n    work();\n    {}\n}}\n",
             "settled();\n    ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
@@ -488,9 +578,19 @@ mod test_kept_rate {
 
     #[test]
     fn test_eprintln_token_alignment() {
-        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
-        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
-        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
+        let base = indoc! {"
+            fn example() {
+                epr
+        "};
+        let candidate = indoc! {r#"
+            fn example() {
+                eprintln!("hello world!");
+        "#};
+        let reference = indoc! {r#"
+            fn example() {
+                eprintln!("");
+        "#};
+
         let result = compute_kept_rate(base, candidate, reference);
         assert!(result.discarded_chars > 0);
         assert!(result.kept_chars > 0);
@@ -499,6 +599,42 @@ mod test_kept_rate {
         assert_eq!(result.discarded_chars, 12);
     }
 
+    #[test]
+    fn test_kept_rate_treats_unchanged_stale_text_as_context() {
+        let base = indoc! {"
+            a=fomr
+            b=old
+        "};
+        let candidate = indoc! {"
+            a=formula;
+            b=old
+        "};
+        let reference = indoc! {"
+            a=formula;
+            b=new
+        "};
+
+        let result = compute_kept_rate(base, candidate, reference);
+        let candidate_tokens = tokenize(candidate);
+
+        assert_eq!(result.candidate_new_chars, "formula".len() + ";".len());
+        assert_eq!(result.kept_chars, "formula".len() + ";".len());
+        assert_eq!(result.discarded_chars, 0);
+        assert_eq!(result.candidate_deleted_chars, "fomr".len());
+        assert_eq!(result.correctly_deleted_chars, "fomr".len());
+        assert!((result.kept_rate - 1.0).abs() < 1e-6);
+        assert!((result.recall_rate - (2.0 / 3.0)).abs() < 1e-6);
+
+        let old_index = candidate_tokens
+            .iter()
+            .position(|&token| token == "old")
+            .expect("old token not found");
+        assert_eq!(
+            result.token_annotations[old_index],
+            TokenAnnotation::Context
+        );
+    }
+
     #[test]
     fn test_annotations_rename() {
         let base = "    foo(old_name)\n";
@@ -514,7 +650,7 @@ mod test_kept_rate {
         assert_eq!(result.token_annotations.len(), tokenize(candidate).len());
 
         for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
-            if token == "new_name" {
+            if matches!(token, "new" | "_" | "name") {
                 assert_eq!(annotation, TokenAnnotation::Kept);
             } else {
                 assert_eq!(annotation, TokenAnnotation::Context);
@@ -524,9 +660,18 @@ mod test_kept_rate {
 
     #[test]
     fn test_annotations_eprintln_coloring() {
-        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
-        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
-        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
+        let base = indoc! {"
+            fn example() {
+                epr
+        "};
+        let candidate = indoc! {r#"
+            fn example() {
+                eprintln!("hello world!");
+        "#};
+        let reference = indoc! {r#"
+            fn example() {
+                eprintln!("");
+        "#};
 
         let result = compute_kept_rate(base, candidate, reference);
         let candidate_tokens = tokenize(candidate);
diff --git a/crates/edit_prediction_metrics/src/main.rs b/crates/edit_prediction_metrics/src/main.rs
new file mode 100644
index 0000000000000000000000000000000000000000..0e557c35e7ff1f73b9c14091cc84b126a2d7b82b
--- /dev/null
+++ b/crates/edit_prediction_metrics/src/main.rs
@@ -0,0 +1,710 @@
+use std::env;
+use std::fmt::Write as _;
+use std::fs;
+use std::path::Path;
+use std::process;
+
+use edit_prediction_metrics::{
+    ClassificationMetrics, DeltaChrFMetrics, KeptRateResult, TokenAnnotation,
+    annotate_kept_rate_tokens, braces_disbalance, compute_kept_rate, count_patch_token_changes,
+    delta_chr_f, exact_lines_match, extract_changed_lines_from_diff,
+    has_isolated_whitespace_changes, is_editable_region_correct,
+};
+use serde::Deserialize;
+
+fn main() {
+    if let Err(error) = run() {
+        eprintln!("error: {error}");
+        process::exit(1);
+    }
+}
+
+fn run() -> Result<(), String> {
+    let args: Vec<String> = env::args().skip(1).collect();
+    if args.is_empty() {
+        print_usage();
+        return Err("missing arguments".to_string());
+    }
+
+    let input = CliInput::parse(&args)?;
+    let report = match input {
+        CliInput::Files {
+            base_path,
+            expected_patch_path,
+            actual_patch_path,
+        } => {
+            let base = fs::read_to_string(&base_path)
+                .map_err(|err| format!("failed to read {}: {err}", base_path.display()))?;
+            let expected_patch = fs::read_to_string(&expected_patch_path).map_err(|err| {
+                format!("failed to read {}: {err}", expected_patch_path.display())
+            })?;
+            let actual_patch = fs::read_to_string(&actual_patch_path)
+                .map_err(|err| format!("failed to read {}: {err}", actual_patch_path.display()))?;
+
+            let expected = apply_patch_to_excerpt(&base, &expected_patch, 0)?;
+            let actual = apply_patch_to_excerpt(&base, &actual_patch, 0)?;
+
+            EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
+        }
+        CliInput::Json {
+            json_path,
+            prediction_index,
+        } => {
+            let json = fs::read_to_string(&json_path)
+                .map_err(|err| format!("failed to read {}: {err}", json_path.display()))?;
+            let example: JsonExample = serde_json::from_str(&json)
+                .map_err(|err| format!("failed to parse {}: {err}", json_path.display()))?;
+
+            let base = example.prompt_inputs.cursor_excerpt;
+            let excerpt_start_row = example.prompt_inputs.excerpt_start_row;
+            let expected_patch = example
+                .expected_patches
+                .into_iter()
+                .next()
+                .ok_or_else(|| "JSON input is missing expected_patches[0]".to_string())?;
+            let actual_patch = example
+                .predictions
+                .into_iter()
+                .nth(prediction_index)
+                .ok_or_else(|| {
+                    format!("JSON input does not contain predictions[{prediction_index}]")
+                })?
+                .actual_patch;
+
+            let expected = apply_patch_to_excerpt(&base, &expected_patch, excerpt_start_row)?;
+            let actual = apply_patch_to_excerpt(&base, &actual_patch, excerpt_start_row)?;
+
+            EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
+        }
+    };
+
+    print_report(&report);
+    Ok(())
+}
+
+fn print_usage() {
+    eprintln!(
+        "Usage:\n  edit_prediction_metrics --base <path> --expected-patch <path> --actual-patch <path>\n  edit_prediction_metrics --json <path> [--prediction-index <n>]"
+    );
+}
+
+enum CliInput {
+    Files {
+        base_path: std::path::PathBuf,
+        expected_patch_path: std::path::PathBuf,
+        actual_patch_path: std::path::PathBuf,
+    },
+    Json {
+        json_path: std::path::PathBuf,
+        prediction_index: usize,
+    },
+}
+
+impl CliInput {
+    fn parse(args: &[String]) -> Result<Self, String> {
+        let mut base_path = None;
+        let mut expected_patch_path = None;
+        let mut actual_patch_path = None;
+        let mut json_path = None;
+        let mut prediction_index = 0usize;
+
+        let mut index = 0;
+        while index < args.len() {
+            match args[index].as_str() {
+                "--base" => {
+                    index += 1;
+                    base_path = Some(path_arg(args, index, "--base")?);
+                }
+                "--expected-patch" => {
+                    index += 1;
+                    expected_patch_path = Some(path_arg(args, index, "--expected-patch")?);
+                }
+                "--actual-patch" => {
+                    index += 1;
+                    actual_patch_path = Some(path_arg(args, index, "--actual-patch")?);
+                }
+                "--json" => {
+                    index += 1;
+                    json_path = Some(path_arg(args, index, "--json")?);
+                }
+                "--prediction-index" => {
+                    index += 1;
+                    let raw = string_arg(args, index, "--prediction-index")?;
+                    prediction_index = raw.parse::<usize>().map_err(|err| {
+                        format!("invalid value for --prediction-index ({raw}): {err}")
+                    })?;
+                }
+                "--help" | "-h" => {
+                    print_usage();
+                    process::exit(0);
+                }
+                unknown => {
+                    return Err(format!("unrecognized argument: {unknown}"));
+                }
+            }
+            index += 1;
+        }
+
+        if let Some(json_path) = json_path {
+            if base_path.is_some() || expected_patch_path.is_some() || actual_patch_path.is_some() {
+                return Err(
+                    "--json cannot be combined with --base/--expected-patch/--actual-patch"
+                        .to_string(),
+                );
+            }
+            return Ok(CliInput::Json {
+                json_path,
+                prediction_index,
+            });
+        }
+
+        match (base_path, expected_patch_path, actual_patch_path) {
+            (Some(base_path), Some(expected_patch_path), Some(actual_patch_path)) => {
+                Ok(CliInput::Files {
+                    base_path,
+                    expected_patch_path,
+                    actual_patch_path,
+                })
+            }
+            _ => Err(
+                "expected either --json or all of --base, --expected-patch, and --actual-patch"
+                    .to_string(),
+            ),
+        }
+    }
+}
+
+fn path_arg(args: &[String], index: usize, flag: &str) -> Result<std::path::PathBuf, String> {
+    Ok(Path::new(string_arg(args, index, flag)?).to_path_buf())
+}
+
+fn string_arg<'a>(args: &'a [String], index: usize, flag: &str) -> Result<&'a str, String> {
+    args.get(index)
+        .map(|value| value.as_str())
+        .ok_or_else(|| format!("missing value for {flag}"))
+}
+
+#[derive(Debug)]
+struct EvaluationReport {
+    base: String,
+    expected: String,
+    actual: String,
+    kept_rate: KeptRateResult,
+    exact_lines: ClassificationMetrics,
+    delta_chr_f: DeltaChrFMetrics,
+    expected_changed_lines: usize,
+    actual_changed_lines: usize,
+    token_changes: edit_prediction_metrics::TokenChangeCounts,
+    isolated_whitespace_changes: bool,
+    editable_region_correct: bool,
+    expected_braces_disbalance: usize,
+    actual_braces_disbalance: usize,
+}
+
+impl EvaluationReport {
+    fn new(
+        base: String,
+        expected_patch: String,
+        actual_patch: String,
+        expected: String,
+        actual: String,
+    ) -> Self {
+        let kept_rate = compute_kept_rate(&base, &actual, &expected);
+        let exact_lines = exact_lines_match(&expected_patch, &actual_patch);
+        let delta_chr_f = delta_chr_f(&base, &expected, &actual);
+        let expected_changed_lines = extract_changed_lines_from_diff(&expected_patch)
+            .values()
+            .sum();
+        let actual_changed_lines = extract_changed_lines_from_diff(&actual_patch)
+            .values()
+            .sum();
+        let token_changes = count_patch_token_changes(&actual_patch);
+        let isolated_whitespace_changes = has_isolated_whitespace_changes(&actual_patch, None);
+        let editable_region_correct = is_editable_region_correct(&actual_patch);
+        let expected_braces_disbalance = braces_disbalance(&expected);
+        let actual_braces_disbalance = braces_disbalance(&actual);
+
+        Self {
+            base,
+            expected,
+            actual,
+            kept_rate,
+            exact_lines,
+            delta_chr_f,
+            expected_changed_lines,
+            actual_changed_lines,
+            token_changes,
+            isolated_whitespace_changes,
+            editable_region_correct,
+            expected_braces_disbalance,
+            actual_braces_disbalance,
+        }
+    }
+}
+
+fn print_report(report: &EvaluationReport) {
+    println!("Metrics");
+    println!("=======");
+    println!("kept_rate: {:.6}", report.kept_rate.kept_rate);
+    println!("kept_rate_recall: {:.6}", report.kept_rate.recall_rate);
+    println!("delta_chr_f: {:.6}", report.delta_chr_f.score);
+    println!("delta_chr_f_precision: {:.6}", report.delta_chr_f.precision);
+    println!("delta_chr_f_recall: {:.6}", report.delta_chr_f.recall);
+    println!("delta_chr_f_beta: {:.6}", report.delta_chr_f.beta);
+    println!();
+
+    println!("Exact line match");
+    println!("----------------");
+    println!("true_positives: {}", report.exact_lines.true_positives);
+    println!("false_positives: {}", report.exact_lines.false_positives);
+    println!("false_negatives: {}", report.exact_lines.false_negatives);
+    println!("precision: {:.6}", report.exact_lines.precision());
+    println!("recall: {:.6}", report.exact_lines.recall());
+    println!("f1: {:.6}", report.exact_lines.f1());
+    println!("expected_changed_lines: {}", report.expected_changed_lines);
+    println!("actual_changed_lines: {}", report.actual_changed_lines);
+    println!();
+
+    println!("Patch structure");
+    println!("---------------");
+    println!("inserted_tokens: {}", report.token_changes.inserted_tokens);
+    println!("deleted_tokens: {}", report.token_changes.deleted_tokens);
+    println!(
+        "isolated_whitespace_changes: {}",
+        report.isolated_whitespace_changes
+    );
+    println!(
+        "editable_region_correct: {}",
+        report.editable_region_correct
+    );
+    println!();
+
+    println!("Final text checks");
+    println!("-----------------");
+    println!(
+        "expected_braces_disbalance: {}",
+        report.expected_braces_disbalance
+    );
+    println!(
+        "actual_braces_disbalance: {}",
+        report.actual_braces_disbalance
+    );
+    println!();
+
+    println!("Kept-rate breakdown");
+    println!("-------------------");
+    println!(
+        "candidate_new_chars: {}",
+        report.kept_rate.candidate_new_chars
+    );
+    println!(
+        "reference_new_chars: {}",
+        report.kept_rate.reference_new_chars
+    );
+    println!(
+        "candidate_deleted_chars: {}",
+        report.kept_rate.candidate_deleted_chars
+    );
+    println!(
+        "reference_deleted_chars: {}",
+        report.kept_rate.reference_deleted_chars
+    );
+    println!("kept_chars: {}", report.kept_rate.kept_chars);
+    println!(
+        "correctly_deleted_chars: {}",
+        report.kept_rate.correctly_deleted_chars
+    );
+    println!("discarded_chars: {}", report.kept_rate.discarded_chars);
+    println!("context_chars: {}", report.kept_rate.context_chars);
+    println!();
+
+    print_kept_rate_explanation(&report.base, &report.actual, &report.expected);
+}
+
+fn print_kept_rate_explanation(base: &str, actual: &str, expected: &str) {
+    println!("Kept-rate explanation");
+    println!("---------------------");
+    println!("Legend: context = default, kept = green background, discarded = red background");
+    println!();
+
+    let annotated = annotate_kept_rate_tokens(base, actual, expected);
+    println!("Actual final text with token annotations:");
+    println!("{}", render_annotated_tokens(&annotated));
+    println!();
+}
+
+fn render_annotated_tokens(tokens: &[edit_prediction_metrics::AnnotatedToken]) -> String {
+    const RESET: &str = "\x1b[0m";
+    const KEPT_STYLE: &str = "\x1b[30;42m";
+    const DISCARDED_STYLE: &str = "\x1b[30;41m";
+
+    let mut rendered = String::new();
+    for token in tokens {
+        let style = match token.annotation {
+            TokenAnnotation::Context => "",
+            TokenAnnotation::Kept => KEPT_STYLE,
+            TokenAnnotation::Discarded => DISCARDED_STYLE,
+        };
+
+        if style.is_empty() {
+            rendered.push_str(&visualize_whitespace(&token.token));
+        } else {
+            rendered.push_str(style);
+            rendered.push_str(&visualize_whitespace(&token.token));
+            rendered.push_str(RESET);
+        }
+    }
+    rendered
+}
+
+fn visualize_whitespace(token: &str) -> String {
+    let mut rendered = String::new();
+    for ch in token.chars() {
+        match ch {
+            ' ' => rendered.push('·'),
+            '\t' => rendered.push('⇥'),
+            '\n' => rendered.push_str("↵\n"),
+            _ => rendered.push(ch),
+        }
+    }
+    rendered
+}
+
+#[derive(Debug, Deserialize)]
+struct JsonExample {
+    prompt_inputs: PromptInputs,
+    expected_patches: Vec<String>,
+    predictions: Vec<Prediction>,
+}
+
+#[derive(Debug, Deserialize)]
+struct PromptInputs {
+    cursor_excerpt: String,
+    excerpt_start_row: u32,
+}
+
+#[derive(Debug, Deserialize)]
+struct Prediction {
+    actual_patch: String,
+}
+
+#[derive(Debug, Clone)]
+struct ParsedHunk {
+    old_start: u32,
+    lines: Vec<HunkLine>,
+}
+
+#[derive(Debug, Clone)]
+enum HunkLine {
+    Context(String),
+    Addition(String),
+    Deletion(String),
+}
+
+fn apply_patch_to_excerpt(
+    base: &str,
+    patch: &str,
+    excerpt_start_row: u32,
+) -> Result<String, String> {
+    let hunks = parse_diff_hunks(patch);
+
+    let result = try_apply_hunks(base, &hunks, excerpt_start_row);
+
+    // Predicted patches may use excerpt-relative line numbers instead of
+    // file-global ones. When all hunks fall outside the excerpt window, the
+    // result is identical to the base text. Retry with a zero offset so the
+    // line numbers are interpreted relative to the excerpt.
+    if excerpt_start_row > 0 && !hunks.is_empty() {
+        let should_retry = match &result {
+            Ok(text) => text == base,
+            Err(_) => true,
+        };
+
+        if should_retry {
+            let fallback = try_apply_hunks(base, &hunks, 0);
+            if matches!(&fallback, Ok(text) if text != base) {
+                return fallback;
+            }
+        }
+    }
+
+    result
+}
+
+fn try_apply_hunks(
+    base: &str,
+    hunks: &[ParsedHunk],
+    excerpt_start_row: u32,
+) -> Result<String, String> {
+    let base_has_trailing_newline = base.ends_with('\n');
+    let mut lines = split_preserving_final_empty_line(base);
+    let original_line_count = lines.len() as u32;
+
+    let excerpt_end_row = excerpt_start_row + original_line_count;
+    let mut line_delta: i64 = 0;
+
+    for hunk in hunks {
+        let filtered = match filter_hunk_to_excerpt(hunk, excerpt_start_row, excerpt_end_row) {
+            Some(filtered) => filtered,
+            None => continue,
+        };
+
+        let local_start = filtered.old_start.saturating_sub(excerpt_start_row) as i64 + line_delta;
+        if local_start < 0 {
+            return Err(format!(
+                "patch application moved before excerpt start at source row {}",
+                filtered.old_start
+            ));
+        }
+        let local_start = local_start as usize;
+
+        if local_start > lines.len() {
+            return Err(format!(
+                "patch application starts past excerpt end at local line {}",
+                local_start + 1
+            ));
+        }
+
+        let old_len = filtered
+            .lines
+            .iter()
+            .filter(|line| !matches!(line, HunkLine::Addition(_)))
+            .count();
+        let new_len = filtered
+            .lines
+            .iter()
+            .filter(|line| !matches!(line, HunkLine::Deletion(_)))
+            .count();
+
+        let old_segment: Vec<&str> = filtered
+            .lines
+            .iter()
+            .filter_map(|line| match line {
+                HunkLine::Context(text) | HunkLine::Deletion(text) => Some(text.as_str()),
+                HunkLine::Addition(_) => None,
+            })
+            .collect();
+
+        let new_segment: Vec<String> = filtered
+            .lines
+            .iter()
+            .filter_map(|line| match line {
+                HunkLine::Context(text) | HunkLine::Addition(text) => Some(text.clone()),
+                HunkLine::Deletion(_) => None,
+            })
+            .collect();
+
+        if local_start + old_len > lines.len() {
+            return Err(format!(
+                "patch application exceeds excerpt bounds near source row {}",
+                filtered.old_start
+            ));
+        }
+
+        let current_segment: Vec<&str> = lines[local_start..local_start + old_len]
+            .iter()
+            .map(String::as_str)
+            .collect();
+
+        if current_segment != old_segment {
+            let mut details = String::new();
+            let _ = write!(
+                details,
+                "patch context mismatch near source row {}: expected {:?}, found {:?}",
+                filtered.old_start, old_segment, current_segment
+            );
+            return Err(details);
+        }
+
+        lines.splice(local_start..local_start + old_len, new_segment);
+        line_delta += new_len as i64 - old_len as i64;
+    }
+
+    Ok(join_lines(&lines, base_has_trailing_newline))
+}
+
+fn split_preserving_final_empty_line(text: &str) -> Vec<String> {
+    let mut lines: Vec<String> = text.lines().map(ToString::to_string).collect();
+    if text.ends_with('\n') {
+        if lines.last().is_some_and(|line| !line.is_empty()) || lines.is_empty() {
+            lines.push(String::new());
+        }
+    }
+    lines
+}
+
+fn join_lines(lines: &[String], had_trailing_newline: bool) -> String {
+    if lines.is_empty() {
+        return String::new();
+    }
+
+    let mut joined = lines.join("\n");
+    if had_trailing_newline && !joined.ends_with('\n') {
+        joined.push('\n');
+    }
+    if !had_trailing_newline && joined.ends_with('\n') {
+        joined.pop();
+    }
+    joined
+}
+
+fn filter_hunk_to_excerpt(
+    hunk: &ParsedHunk,
+    excerpt_start_row: u32,
+    excerpt_end_row: u32,
+) -> Option<ParsedHunk> {
+    let mut filtered_lines = Vec::new();
+    let mut current_old_row = hunk.old_start.saturating_sub(1);
+    let mut filtered_old_start = None;
+    let mut has_overlap = false;
+
+    for line in &hunk.lines {
+        match line {
+            HunkLine::Context(text) => {
+                let in_excerpt =
+                    current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
+                if in_excerpt {
+                    filtered_old_start.get_or_insert(current_old_row);
+                    filtered_lines.push(HunkLine::Context(text.clone()));
+                    has_overlap = true;
+                }
+                current_old_row += 1;
+            }
+            HunkLine::Deletion(text) => {
+                let in_excerpt =
+                    current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
+                if in_excerpt {
+                    filtered_old_start.get_or_insert(current_old_row);
+                    filtered_lines.push(HunkLine::Deletion(text.clone()));
+                    has_overlap = true;
+                }
+                current_old_row += 1;
+            }
+            HunkLine::Addition(text) => {
+                let insertion_in_excerpt =
+                    current_old_row >= excerpt_start_row && current_old_row <= excerpt_end_row;
+                if insertion_in_excerpt {
+                    filtered_old_start.get_or_insert(current_old_row);
+                    filtered_lines.push(HunkLine::Addition(text.clone()));
+                    has_overlap = true;
+                }
+            }
+        }
+    }
+
+    if !has_overlap {
+        return None;
+    }
+
+    Some(ParsedHunk {
+        old_start: filtered_old_start.unwrap_or(excerpt_start_row),
+        lines: filtered_lines,
+    })
+}
+
+fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
+    let mut hunks = Vec::new();
+    let mut current_hunk: Option<ParsedHunk> = None;
+
+    for line in diff.lines() {
+        if let Some((old_start, old_count, _new_start, _new_count)) = parse_hunk_header(line) {
+            if let Some(hunk) = current_hunk.take() {
+                hunks.push(hunk);
+            }
+            let _ = old_count;
+            current_hunk = Some(ParsedHunk {
+                old_start,
+                lines: Vec::new(),
+            });
+            continue;
+        }
+
+        let Some(hunk) = current_hunk.as_mut() else {
+            continue;
+        };
+
+        if let Some(text) = line.strip_prefix('+') {
+            if !line.starts_with("+++") {
+                hunk.lines.push(HunkLine::Addition(text.to_string()));
+            }
+        } else if let Some(text) = line.strip_prefix('-') {
+            if !line.starts_with("---") {
+                hunk.lines.push(HunkLine::Deletion(text.to_string()));
+            }
+        } else if let Some(text) = line.strip_prefix(' ') {
+            hunk.lines.push(HunkLine::Context(text.to_string()));
+        } else if line.is_empty() {
+            hunk.lines.push(HunkLine::Context(String::new()));
+        }
+    }
+
+    if let Some(hunk) = current_hunk {
+        hunks.push(hunk);
+    }
+
+    hunks
+}
+
+fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
+    let line = line.strip_prefix("@@ -")?;
+    let (old_part, rest) = line.split_once(' ')?;
+    let rest = rest.strip_prefix('+')?;
+    let (new_part, _) = rest.split_once(" @@")?;
+
+    let (old_start, old_count) = parse_hunk_range(old_part)?;
+    let (new_start, new_count) = parse_hunk_range(new_part)?;
+    Some((old_start, old_count, new_start, new_count))
+}
+
+fn parse_hunk_range(part: &str) -> Option<(u32, u32)> {
+    if let Some((start, count)) = part.split_once(',') {
+        Some((start.parse().ok()?, count.parse().ok()?))
+    } else {
+        Some((part.parse().ok()?, 1))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn applies_patch_in_file_mode() {
+        let base = "fn main() {\n    println!(\"hello\");\n}\n";
+        let patch = "@@ -1,3 +1,3 @@\n fn main() {\n-    println!(\"hello\");\n+    println!(\"world\");\n }\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 0).unwrap();
+        assert_eq!(actual, "fn main() {\n    println!(\"world\");\n}\n");
+    }
+
+    #[test]
+    fn applies_patch_in_json_excerpt_mode() {
+        let base = "b\nc\nd\n";
+        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 1).unwrap();
+        assert_eq!(actual, "x\ny\nd\n");
+    }
+
+    #[test]
+    fn applies_patch_with_excerpt_relative_line_numbers() {
+        let base = "a\nb\nc\nd\n";
+        // Patch uses excerpt-relative line numbers (line 2 of the excerpt)
+        // even though the excerpt starts at file row 100.
+        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 100).unwrap();
+        assert_eq!(actual, "a\nx\ny\nd\n");
+    }
+
+    #[test]
+    fn prefers_file_global_line_numbers_over_excerpt_relative() {
+        let base = "a\nb\nc\n";
+        // Patch uses file-global line numbers: the excerpt starts at row 5, and
+        // the hunk targets line 6 (1-based) = row 5 (0-based) = the first line.
+        let patch = "@@ -6,2 +6,2 @@\n-a\n-b\n+x\n+y\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 5).unwrap();
+        assert_eq!(actual, "x\ny\nc\n");
+    }
+}
diff --git a/crates/edit_prediction_metrics/src/patch_metrics.rs b/crates/edit_prediction_metrics/src/patch_metrics.rs
index 9da499796efabc8e1a767dd1b2ed3843b38d06eb..85470da91c59d514d57b6d1080e92a40b695e202 100644
--- a/crates/edit_prediction_metrics/src/patch_metrics.rs
+++ b/crates/edit_prediction_metrics/src/patch_metrics.rs
@@ -687,6 +687,35 @@ fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<TokenDiff> {
         .collect()
 }
 
+/// Reconstruct old and new text from a unified diff.
+///
+/// Context and deletion lines form the old text; context and addition
+/// lines form the new text. Returns `(old_text, new_text)`.
+pub fn reconstruct_texts_from_diff(patch_str: &str) -> (String, String) {
+    let patch = Patch::parse_unified_diff(patch_str);
+    let mut old_lines: Vec<&str> = Vec::new();
+    let mut new_lines: Vec<&str> = Vec::new();
+
+    for hunk in &patch.hunks {
+        for line in &hunk.lines {
+            match line {
+                PatchLine::Context(content) => {
+                    old_lines.push(content);
+                    new_lines.push(content);
+                }
+                PatchLine::Deletion(content) => {
+                    old_lines.push(content);
+                }
+                PatchLine::Addition(content) => {
+                    new_lines.push(content);
+                }
+                PatchLine::Garbage(_) => {}
+            }
+        }
+    }
+
+    (old_lines.join("\n"), new_lines.join("\n"))
+}
+
 #[derive(Debug, Default, Clone)]
 struct Patch {
     hunks: Vec<Hunk>,
diff --git a/crates/edit_prediction_metrics/src/tokenize.rs b/crates/edit_prediction_metrics/src/tokenize.rs
index 250a5c15167cbcd00e5ee3fb0397cfed011be5bb..72d5535e48854394a924e87e3c122904db21921c 100644
--- a/crates/edit_prediction_metrics/src/tokenize.rs
+++ b/crates/edit_prediction_metrics/src/tokenize.rs
@@ -1,33 +1,158 @@
-fn char_class(character: char) -> u8 {
-    if character.is_alphanumeric() || character == '_' {
-        0
+use std::{iter::Peekable, str::CharIndices};
+
+#[derive(Clone, Copy, PartialEq, Eq)]
+enum CharClass {
+    Identifier,
+    Newline,
+    Whitespace,
+    Punctuation,
+}
+
+const MULTI_CHAR_PUNCTUATION: &[&str] = &[
+    ">>>=", "<<=", ">>=", "...", "..=", "??=", "**=", ">>>", "::", "->", "=>", "==", "!=", "<=",
+    ">=", "&&", "||", "<<", ">>", "..", "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", "++", "--",
+    "**", "??", "?.", ":=", "<-", "//", "/*", "*/",
+];
+
+fn char_class(character: char) -> CharClass {
+    if character == '\n' || character == '\r' {
+        CharClass::Newline
     } else if character.is_whitespace() {
-        1
+        CharClass::Whitespace
+    } else if character.is_alphanumeric() || character == '_' {
+        CharClass::Identifier
     } else {
-        2
+        CharClass::Punctuation
     }
 }
 
+fn is_identifier_boundary(previous: char, current: char, next: Option<char>) -> bool {
+    (current.is_uppercase() && (previous.is_lowercase() || previous.is_numeric()))
+        || (current.is_uppercase()
+            && previous.is_uppercase()
+            && next.is_some_and(|next| next.is_lowercase()))
+}
+
+fn push_identifier_tokens<'a>(identifier: &'a str, tokens: &mut Vec<&'a str>) {
+    let characters: Vec<(usize, char)> = identifier.char_indices().collect();
+    let mut segment_start = 0;
+    let mut index = 0;
+
+    while index < characters.len() {
+        let (byte_index, character) = characters[index];
+
+        if character == '_' {
+            if segment_start < byte_index {
+                tokens.push(&identifier[segment_start..byte_index]);
+            }
+
+            let mut underscore_end = byte_index + character.len_utf8();
+            index += 1;
+
+            while index < characters.len() && characters[index].1 == '_' {
+                underscore_end = characters[index].0 + characters[index].1.len_utf8();
+                index += 1;
+            }
+
+            tokens.push(&identifier[byte_index..underscore_end]);
+            segment_start = underscore_end;
+            continue;
+        }
+
+        if byte_index > segment_start {
+            let previous = characters[index - 1].1;
+            let next = characters.get(index + 1).map(|(_, character)| *character);
+
+            if is_identifier_boundary(previous, character, next) {
+                tokens.push(&identifier[segment_start..byte_index]);
+                segment_start = byte_index;
+            }
+        }
+
+        index += 1;
+    }
+
+    if segment_start < identifier.len() {
+        tokens.push(&identifier[segment_start..]);
+    }
+}
+
+fn push_punctuation_token<'a>(
+    text: &'a str,
+    start: usize,
+    character: char,
+    characters: &mut Peekable<CharIndices<'a>>,
+    tokens: &mut Vec<&'a str>,
+) {
+    let remaining = &text[start..];
+
+    for punctuation in MULTI_CHAR_PUNCTUATION {
+        if remaining.starts_with(punctuation) {
+            for _ in punctuation.chars().skip(1) {
+                characters.next();
+            }
+
+            tokens.push(&remaining[..punctuation.len()]);
+            return;
+        }
+    }
+
+    let end = start + character.len_utf8();
+    tokens.push(&text[start..end]);
+}
+
 pub(crate) fn tokenize(text: &str) -> Vec<&str> {
     let mut tokens = Vec::new();
     let mut characters = text.char_indices().peekable();
 
     while let Some((start, character)) = characters.next() {
-        let class = char_class(character);
-        if class == 2 {
-            tokens.push(&text[start..start + character.len_utf8()]);
-            continue;
-        }
+        match char_class(character) {
+            CharClass::Identifier => {
+                let mut end = start + character.len_utf8();
+
+                while let Some(&(next_start, next_character)) = characters.peek() {
+                    if char_class(next_character) != CharClass::Identifier {
+                        break;
+                    }
+
+                    end = next_start + next_character.len_utf8();
+                    characters.next();
+                }
+
+                push_identifier_tokens(&text[start..end], &mut tokens);
+            }
+            CharClass::Newline => {
+                let mut end = start + character.len_utf8();
+
+                while let Some(&(next_start, next_character)) = characters.peek() {
+                    if char_class(next_character) != CharClass::Newline {
+                        break;
+                    }
+
+                    end = next_start + next_character.len_utf8();
+                    characters.next();
+                }
 
-        let mut end = start + character.len_utf8();
-        while let Some(&(_, next_character)) = characters.peek() {
-            if char_class(next_character) != class {
-                break;
+                tokens.push(&text[start..end]);
+            }
+            CharClass::Whitespace => {
+                let mut end = start + character.len_utf8();
+
+                while let Some(&(next_start, next_character)) = characters.peek() {
+                    if char_class(next_character) != CharClass::Whitespace {
+                        break;
+                    }
+
+                    end = next_start + next_character.len_utf8();
+                    characters.next();
+                }
+
+                tokens.push(&text[start..end]);
+            }
+            CharClass::Punctuation => {
+                push_punctuation_token(text, start, character, &mut characters, &mut tokens);
             }
-            end += next_character.len_utf8();
-            characters.next();
         }
-        tokens.push(&text[start..end]);
     }
 
     tokens
@@ -38,17 +163,58 @@ mod tests {
     use super::tokenize;
 
     #[test]
-    fn tokenizes_code_like_text() {
+    fn tokenizes_code() {
         assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
         assert_eq!(
             tokenize("foo_bar123 + baz"),
-            vec!["foo_bar123", " ", "+", " ", "baz"]
+            vec!["foo", "_", "bar123", " ", "+", " ", "baz"]
         );
         assert_eq!(
             tokenize("print(\"hello\")"),
             vec!["print", "(", "\"", "hello", "\"", ")"]
         );
-        assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
+        assert_eq!(tokenize("hello_world"), vec!["hello", "_", "world"]);
         assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
     }
+
+    #[test]
+    fn tokenizes_identifier_case_styles() {
+        assert_eq!(
+            tokenize("camelCase PascalCase snake_case"),
+            vec![
+                "camel", "Case", " ", "Pascal", "Case", " ", "snake", "_", "case"
+            ]
+        );
+        assert_eq!(
+            tokenize("myHTTPServer __private_value foo__bar"),
+            vec![
+                "my", "HTTP", "Server", " ", "__", "private", "_", "value", " ", "foo", "__", "bar"
+            ]
+        );
+        assert_eq!(
+            tokenize("XMLHttpRequest Version2Update"),
+            vec!["XML", "Http", "Request", " ", "Version2", "Update"]
+        );
+    }
+
+    #[test]
+    fn tokenizes_grouped_punctuation() {
+        assert_eq!(
+            tokenize("a::b -> c != d ..= e"),
+            vec![
+                "a", "::", "b", " ", "->", " ", "c", " ", "!=", " ", "d", " ", "..=", " ", "e"
+            ]
+        );
+        assert_eq!(
+            tokenize("foo?.bar ?? baz"),
+            vec!["foo", "?.", "bar", " ", "??", " ", "baz"]
+        );
+    }
+
+    #[test]
+    fn tokenize_whitespace_runs() {
+        assert_eq!(tokenize("   "), vec!["   "]);
+        assert_eq!(tokenize("  \n  foo"), vec!["  ", "\n", "  ", "foo"]);
+        assert_eq!(tokenize("\r\n\nfoo"), vec!["\r\n\n", "foo"]);
+    }
 }
diff --git a/typos.toml b/typos.toml
index f647c5ac91e1d57e410fe175ce8399348d838a54..f2cd6d18be3f0c134d339c8533254e2619309408 100644
--- a/typos.toml
+++ b/typos.toml
@@ -63,6 +63,7 @@ extend-exclude = [
     "crates/gpui_macos/src/dispatcher.rs",
     # Tests contain partially incomplete words (by design)
     "crates/edit_prediction_cli/src/split_commit.rs",
+    "crates/edit_prediction_metrics/src/kept_rate.rs",
     # Eval examples contain intentionally partial words (e.g. "secur" for "secure")
     "crates/edit_prediction_cli/evals/",
     # Tests contain `baˇr` that cause `"ba" should be "by" or "be".`-like false-positives