/// Returns a rough measure of how unbalanced the brackets in `text` are.
///
/// For each of the three bracket families (`{}`, `()` and `[]`) the absolute
/// difference between the number of opening and closing characters is taken,
/// and the three differences are added together. A result of `0` means every
/// family has exactly as many openers as closers.
///
/// This is a count-based heuristic only: nesting order is not checked (so
/// `"}{"` scores `0`), and brackets that appear inside string literals or
/// comments are counted like any others.
pub fn braces_disbalance(text: &str) -> usize {
    // Counters are indexed by bracket family: 0 = `{}`, 1 = `()`, 2 = `[]`.
    let mut opened = [0usize; 3];
    let mut closed = [0usize; 3];

    // A single pass over the text updates all six counters.
    for c in text.chars() {
        match c {
            '{' => opened[0] += 1,
            '}' => closed[0] += 1,
            '(' => opened[1] += 1,
            ')' => closed[1] += 1,
            '[' => opened[2] += 1,
            ']' => closed[2] += 1,
            _ => {}
        }
    }

    // `abs_diff` stays in `usize`, so no signed casts are needed.
    opened
        .iter()
        .zip(&closed)
        .map(|(&open, &close)| open.abs_diff(close))
        .sum()
}
cae85fbfa5fb5f9950aa5d3e11b90937634c1ece..eaa42da71883f1069a947ca827cb3f1ef27eb891 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -39,33 +39,27 @@ pub async fn run_scoring( }) .collect::, _>>()?; + let zero_scores = ExampleScore { + delta_chr_f: 0.0, + braces_disbalance: 0, + }; + progress.set_substatus("computing metrics"); let mut scores = vec![]; for prediction in &example.predictions { - let actual_patch = match &prediction.actual_patch { - Some(patch) => patch.clone(), - None => { - if prediction.actual_output.is_empty() { - scores.push(ExampleScore { delta_chr_f: 0.0 }); - continue; - } - match parse_prediction_output( - example, - &prediction.actual_output, - prediction.provider, - ) { - Ok(patch) => patch, - Err(_) => { - scores.push(ExampleScore { delta_chr_f: 0.0 }); - continue; - } - } - } + let actual_patch = prediction.actual_patch.clone().or_else(|| { + parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok() + }); + + let Some(actual_patch) = actual_patch else { + scores.push(zero_scores.clone()); + continue; }; + let actual_text = match apply_diff_to_string(&actual_patch, original_text) { Ok(text) => text, Err(_) => { - scores.push(ExampleScore { delta_chr_f: 0.0 }); + scores.push(zero_scores.clone()); continue; } }; @@ -73,8 +67,24 @@ pub async fn run_scoring( .iter() .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32) .fold(0.0, f32::max); + + let disbalance_before = metrics::braces_disbalance(&original_text); + let disbalance_after = metrics::braces_disbalance(&actual_text); + let braces_disbalance = disbalance_after.saturating_sub(disbalance_before); + if braces_disbalance > 0 { + std::fs::write( + "/tmp/unbalanced-count.before", + disbalance_before.to_string(), + ) + .ok(); + std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok(); + std::fs::write("/tmp/unbalanced-text.before", &original_text).ok(); + 
std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok(); + } + scores.push(ExampleScore { delta_chr_f: best_delta_chr_f, + braces_disbalance, }); } @@ -86,22 +96,30 @@ pub fn print_report(examples: &[Example]) { eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" ); - eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF"); + eprintln!( + "{:<50} {:>14} {:>10}", + "Example name", "BracesDisbalance", "DeltaChrF" + ); eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" ); let mut all_delta_chr_f_scores = Vec::new(); + let mut braces_disbalance_sum: usize = 0; + let mut total_scores: usize = 0; for example in examples { for score in example.score.iter() { eprintln!( - "{:<50} {:>9.2}", + "{:<50} {:>14} {:>9.2}", truncate_name(&example.spec.name, 50), + score.braces_disbalance, score.delta_chr_f ); all_delta_chr_f_scores.push(score.delta_chr_f); + total_scores += 1; + braces_disbalance_sum += score.braces_disbalance; } } @@ -112,8 +130,13 @@ pub fn print_report(examples: &[Example]) { if !all_delta_chr_f_scores.is_empty() { let avg_delta_chr_f: f32 = all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32; + let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32; + let braces_disbalance_display = format!("{:.2}", braces_disbalance_avg); - eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f); + eprintln!( + "{:<50} {:>14} {:>9.2}", + "AVERAGE", braces_disbalance_display, avg_delta_chr_f + ); eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" );