diff --git a/Cargo.lock b/Cargo.lock index 6a8ba7c667624e0ef434370428cda484f171d494..2c0df1c10f99635793a76b42fdea6ee95b739145 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2090,7 +2090,7 @@ dependencies = [ "bitflags 2.9.4", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "log", "prettyplease", "proc-macro2", @@ -2110,7 +2110,7 @@ dependencies = [ "bitflags 2.9.4", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "regex", @@ -12965,7 +12965,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes 1.11.1", "heck 0.5.0", - "itertools 0.10.5", + "itertools 0.12.1", "log", "multimap 0.10.1", "once_cell", @@ -12998,7 +12998,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.106", diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index a491b48e242a1a648ed4535a1397122ad9674183..42f67e91951d733d01af9c24d2682af5f663319e 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -116,6 +116,7 @@ pub struct ExampleScore { pub cursor_distance: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub cursor_exact_match: Option, + pub wrong_editable_region: Option, } impl Example { diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index 6e4f4dc09b04a7c9cfdc15ec92fc9e927370c989..e6f1d49e19742a2a9acb70397a460a98c13b16c2 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,5 +1,7 @@ use collections::HashMap; +use crate::reorder_patch::{Patch, PatchLine}; + pub type Counts = HashMap; type CountsDelta = HashMap; @@ -382,6 +384,32 @@ pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> Classifica ClassificationMetrics::from_counts(&expected_lines, &actual_lines) } +/// A simple proxy for whether the prediction respects editable region. +pub fn is_editable_region_correct(actual_patch: &str) -> bool { + // A typical sign of a wrong editable region: a bunch of lines deletion + // at the beginning or end of the patch. + let patch = Patch::parse_unified_diff(actual_patch); + if patch.hunks.is_empty() { + return true; + } + + let hunk = &patch.hunks[0]; + let mut deletions_at_start = 0; + + for line in hunk.lines.iter() { + match line { + PatchLine::Deletion(_) => deletions_at_start += 1, + _ => break, + } + } + + if deletions_at_start >= 3 { + return false; + } + + true +} + #[cfg(test)] mod test_optimization { use super::*; @@ -530,6 +558,7 @@ mod test_optimization { #[cfg(test)] mod test { use super::*; + use indoc::indoc; #[test] fn test_delta_chr_f_perfect_match() { @@ -726,4 +755,23 @@ index abc123..def456 100644 assert_eq!(metrics.false_positives, 0); assert_eq!(metrics.false_negatives, 0); } + + #[test] + fn test_is_editable_region_correct() { + let patch = indoc! {" + @@ -1,1 +1,1 @@ + -context + -removed + -from the beginning of the file + import sys + +sys.exit(0) + + "}; + assert!(!is_editable_region_correct(patch)); + + let patch = indoc! {" + @@ -1,1 +1,1 @@ + "}; + assert!(is_editable_region_correct(patch)); + } } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 772f518b157a717f87ca9d5b704104fd4dde7181..78a13095f6c5e1dd4dfff74c727c099aeb6cc598 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -74,6 +74,7 @@ pub async fn run_scoring( reversal_ratio: 0.0, cursor_distance: None, cursor_exact_match: None, + wrong_editable_region: None, }; let prompt_inputs = example.prompt_inputs.as_ref().unwrap(); @@ -140,16 +141,6 @@ pub async fn run_scoring( let disbalance_before = metrics::braces_disbalance(&original_text); let disbalance_after = metrics::braces_disbalance(&actual_text); let braces_disbalance = disbalance_after.saturating_sub(disbalance_before); - if braces_disbalance > 0 { - std::fs::write( - "/tmp/unbalanced-count.before", - disbalance_before.to_string(), - ) - .ok(); - std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok(); - std::fs::write("/tmp/unbalanced-text.before", &original_text).ok(); - std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok(); - } // Compute exact lines match against best matching expected patch let best_exact_lines = expected_patches_with_cursors @@ -169,6 +160,9 @@ pub async fn run_scoring( let (cursor_distance, cursor_exact_match) = compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor_offset); + // Compute approximation of editable region correctness + let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch)); + scores.push(ExampleScore { delta_chr_f: best_delta_chr_f, braces_disbalance, @@ -178,6 +172,7 @@ pub async fn run_scoring( reversal_ratio, cursor_distance, cursor_exact_match, + wrong_editable_region, }); } @@ -209,13 +204,13 @@ fn compute_cursor_metrics( pub fn print_report(examples: &[Example]) { use crate::metrics::ClassificationMetrics; - const LINE_WIDTH: usize = 94; + const LINE_WIDTH: usize = 101; let separator = "─".repeat(LINE_WIDTH); println!("{}", separator); println!( - "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6}", - "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor" + "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}", + "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER" ); println!("{}", separator); @@ -232,6 +227,8 @@ pub fn print_report(examples: &[Example]) { let mut cursor_total: usize = 0; let mut cursor_distance_sum: usize = 0; let mut cursor_distance_count: usize = 0; + let mut wrong_editable_region_count: usize = 0; + let mut wrong_editable_region_total: usize = 0; for example in examples { for (score_idx, score) in example.score.iter().enumerate() { @@ -252,6 +249,13 @@ pub fn print_report(examples: &[Example]) { .map(|v| format!("{}", v)) .unwrap_or("-".to_string()); + // Format wrong editable region metric + let wrong_er_str = match score.wrong_editable_region { + Some(true) => "✗", + Some(false) => "", + None => "", + }; + // Format cursor metric let cursor_str = match (score.cursor_exact_match, score.cursor_distance) { (Some(true), _) => "✓".to_string(), @@ -261,7 +265,7 @@ pub fn print_report(examples: &[Example]) { }; println!( - "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}", + "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}", truncate_name(&example.spec.name, 40), score.delta_chr_f, score.braces_disbalance, @@ -269,7 +273,8 @@ pub fn print_report(examples: &[Example]) { score.reversal_ratio * 100.0, qa_reverts_str, qa_conf_str, - cursor_str + cursor_str, + wrong_er_str ); all_delta_chr_f_scores.push(score.delta_chr_f); @@ -294,6 +299,14 @@ pub fn print_report(examples: &[Example]) { } } + // Accumulate wrong editable region metrics + if let Some(wrong) = score.wrong_editable_region { + wrong_editable_region_total += 1; + if wrong { + wrong_editable_region_count += 1; + } + } + // Accumulate cursor metrics if let Some(exact_match) = score.cursor_exact_match { cursor_total += 1; @@ -341,6 +354,14 @@ pub fn print_report(examples: &[Example]) { } else { "-".to_string() }; + let wrong_er_str = if wrong_editable_region_total > 0 { + format!( + "{:.2}%", + wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0 + ) + } else { + "-".to_string() + }; let avg_cursor_distance = if cursor_distance_count > 0 { Some(cursor_distance_sum as f32 / cursor_distance_count as f32) } else { @@ -348,7 +369,7 @@ pub fn print_report(examples: &[Example]) { }; println!( - "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}", + "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}", "TOTAL / AVERAGE", avg_delta_chr_f, braces_disbalance_avg, @@ -356,7 +377,8 @@ pub fn print_report(examples: &[Example]) { avg_reversal_ratio * 100.0, qa_reverts_str, qa_conf_str, - cursor_str + cursor_str, + wrong_er_str ); println!("{}", separator); @@ -405,6 +427,8 @@ pub struct SummaryJson { pub cursor_avg_distance: Option, #[serde(skip_serializing_if = "Option::is_none")] pub cursor_total_evaluated: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub wrong_editable_region_rate: Option, } pub fn compute_summary(examples: &[Example]) -> SummaryJson { @@ -423,6 +447,8 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut cursor_total: usize = 0; let mut cursor_distance_sum: usize = 0; let mut cursor_distance_count: usize = 0; + let mut wrong_editable_region_count: usize = 0; + let mut wrong_editable_region_total: usize = 0; for example in examples { for (score_idx, score) in example.score.iter().enumerate() { @@ -448,6 +474,14 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { } } + // Accumulate wrong editable region metrics + if let Some(wrong) = score.wrong_editable_region { + wrong_editable_region_total += 1; + if wrong { + wrong_editable_region_count += 1; + } + } + // Accumulate cursor metrics if let Some(exact_match) = score.cursor_exact_match { cursor_total += 1; @@ -510,6 +544,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { None }; + let wrong_editable_region_rate = if wrong_editable_region_total > 0 { + Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32) + } else { + None + }; + SummaryJson { total_examples: total_scores, avg_delta_chr_f, @@ -526,6 +566,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { cursor_exact_match_rate, cursor_avg_distance, cursor_total_evaluated, + wrong_editable_region_rate, } }