ep: Check whether predictions worsen brace balance (#47301)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/example.rs |  1 
crates/edit_prediction_cli/src/metrics.rs | 30 ++++++++++
crates/edit_prediction_cli/src/score.rs   | 69 ++++++++++++++++--------
3 files changed, 77 insertions(+), 23 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/example.rs 🔗

@@ -82,6 +82,7 @@ pub struct ExamplePrediction {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ExampleScore {
     pub delta_chr_f: f32,
+    pub braces_disbalance: usize,
 }
 
 impl Example {

crates/edit_prediction_cli/src/metrics.rs 🔗

@@ -317,6 +317,24 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
     counts
 }
 
+pub fn braces_disbalance(text: &str) -> usize {
+    let mut disbalance = 0isize;
+
+    let a = text.chars().filter(|&c| c == '{').count() as isize;
+    let b = text.chars().filter(|&c| c == '}').count() as isize;
+    disbalance += (a - b).abs();
+
+    let a = text.chars().filter(|&c| c == '(').count() as isize;
+    let b = text.chars().filter(|&c| c == ')').count() as isize;
+    disbalance += (a - b).abs();
+
+    let a = text.chars().filter(|&c| c == '[').count() as isize;
+    let b = text.chars().filter(|&c| c == ']').count() as isize;
+    disbalance += (a - b).abs();
+
+    disbalance as usize
+}
+
 #[cfg(test)]
 mod test_optimization {
     use super::*;
@@ -529,4 +547,16 @@ mod test {
         let score = delta_chr_f(text, text, text);
         assert!((score - 100.0).abs() < 1e-2);
     }
+
+    #[test]
+    fn test_braces_disbalance() {
+        let text = "let x = { 1 + 2 };";
+        assert_eq!(braces_disbalance(text), 0);
+
+        let text = "let x = { 1 + 2";
+        assert_eq!(braces_disbalance(text), 1);
+
+        let text = "let x = { 1 + 2 )";
+        assert_eq!(braces_disbalance(text), 2);
+    }
 }

crates/edit_prediction_cli/src/score.rs 🔗

@@ -39,33 +39,27 @@ pub async fn run_scoring(
         })
         .collect::<Result<Vec<_>, _>>()?;
 
+    let zero_scores = ExampleScore {
+        delta_chr_f: 0.0,
+        braces_disbalance: 0,
+    };
+
     progress.set_substatus("computing metrics");
     let mut scores = vec![];
     for prediction in &example.predictions {
-        let actual_patch = match &prediction.actual_patch {
-            Some(patch) => patch.clone(),
-            None => {
-                if prediction.actual_output.is_empty() {
-                    scores.push(ExampleScore { delta_chr_f: 0.0 });
-                    continue;
-                }
-                match parse_prediction_output(
-                    example,
-                    &prediction.actual_output,
-                    prediction.provider,
-                ) {
-                    Ok(patch) => patch,
-                    Err(_) => {
-                        scores.push(ExampleScore { delta_chr_f: 0.0 });
-                        continue;
-                    }
-                }
-            }
+        let actual_patch = prediction.actual_patch.clone().or_else(|| {
+            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
+        });
+
+        let Some(actual_patch) = actual_patch else {
+            scores.push(zero_scores.clone());
+            continue;
         };
+
         let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
             Ok(text) => text,
             Err(_) => {
-                scores.push(ExampleScore { delta_chr_f: 0.0 });
+                scores.push(zero_scores.clone());
                 continue;
             }
         };
@@ -73,8 +67,24 @@ pub async fn run_scoring(
             .iter()
             .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
             .fold(0.0, f32::max);
+
+        let disbalance_before = metrics::braces_disbalance(&original_text);
+        let disbalance_after = metrics::braces_disbalance(&actual_text);
+        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
+        if braces_disbalance > 0 {
+            std::fs::write(
+                "/tmp/unbalanced-count.before",
+                disbalance_before.to_string(),
+            )
+            .ok();
+            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
+            std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
+            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
+        }
+
         scores.push(ExampleScore {
             delta_chr_f: best_delta_chr_f,
+            braces_disbalance,
         });
     }
 
@@ -86,22 +96,30 @@ pub fn print_report(examples: &[Example]) {
     eprintln!(
         "──────────────────────────────────────────────────────────────────────────────────────"
     );
-    eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
+    eprintln!(
+        "{:<50} {:>14} {:>10}",
+        "Example name", "BracesDisbalance", "DeltaChrF"
+    );
     eprintln!(
         "──────────────────────────────────────────────────────────────────────────────────────"
     );
 
     let mut all_delta_chr_f_scores = Vec::new();
+    let mut braces_disbalance_sum: usize = 0;
+    let mut total_scores: usize = 0;
 
     for example in examples {
         for score in example.score.iter() {
             eprintln!(
-                "{:<50} {:>9.2}",
+                "{:<50} {:>14} {:>9.2}",
                 truncate_name(&example.spec.name, 50),
+                score.braces_disbalance,
                 score.delta_chr_f
             );
 
             all_delta_chr_f_scores.push(score.delta_chr_f);
+            total_scores += 1;
+            braces_disbalance_sum += score.braces_disbalance;
         }
     }
 
@@ -112,8 +130,13 @@ pub fn print_report(examples: &[Example]) {
     if !all_delta_chr_f_scores.is_empty() {
         let avg_delta_chr_f: f32 =
             all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
+        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
+        let braces_disbalance_display = format!("{:.2}", braces_disbalance_avg);
 
-        eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
+        eprintln!(
+            "{:<50} {:>14} {:>9.2}",
+            "AVERAGE", braces_disbalance_display, avg_delta_chr_f
+        );
         eprintln!(
             "──────────────────────────────────────────────────────────────────────────────────────"
         );