ep: Option to save evaluation report in JSON (#47409)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/main.rs  | 45 +++++++++++--
crates/edit_prediction_cli/src/score.rs | 88 ++++++++++++++++++++++++--
2 files changed, 118 insertions(+), 15 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/main.rs 🔗

@@ -146,7 +146,7 @@ enum Command {
     /// predicted outputs and removing actual outputs and prompts.
     Distill,
     /// Print aggregated scores
-    Eval(PredictArgs),
+    Eval(EvalArgs),
     /// Generate eval examples by analyzing git commits from a repository
     Synthesize(SynthesizeArgs),
     /// Remove git repositories and worktrees
@@ -180,7 +180,7 @@ impl Display for Command {
                 None => write!(f, "score"),
             },
             Command::Distill => write!(f, "distill"),
-            Command::Eval(args) => match &args.provider {
+            Command::Eval(args) => match &args.predict.provider {
                 Some(provider) => write!(f, "eval --provider={}", provider),
                 None => write!(f, "eval"),
             },
@@ -212,6 +212,15 @@ struct PredictArgs {
     repetitions: usize,
 }
 
+#[derive(Debug, Args, Clone)]
+struct EvalArgs {
+    #[clap(flatten)]
+    predict: PredictArgs,
+    /// Path to write summary scores as JSON
+    #[clap(long)]
+    summary_json: Option<PathBuf>,
+}
+
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 enum PredictionProvider {
     Sweep,
@@ -570,9 +579,12 @@ fn main() {
                 .await?;
 
                 match &command {
-                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
+                    Command::Predict(args) | Command::Score(args) => {
                         predict::sync_batches(args.provider.as_ref()).await?;
                     }
+                    Command::Eval(args) => {
+                        predict::sync_batches(args.predict.provider.as_ref()).await?;
+                    }
                     _ => (),
                 }
 
@@ -687,10 +699,20 @@ fn main() {
                                         Command::Distill => {
                                             run_distill(example).await?;
                                         }
-                                        Command::Score(args) | Command::Eval(args) => {
+                                        Command::Score(args) => {
                                             run_scoring(
                                                 example,
-                                                &args,
+                                                args,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
+                                            .await?;
+                                        }
+                                        Command::Eval(args) => {
+                                            run_scoring(
+                                                example,
+                                                &args.predict,
                                                 app_state.clone(),
                                                 &example_progress,
                                                 cx.clone(),
@@ -786,14 +808,23 @@ fn main() {
                 Progress::global().finalize();
 
                 match &command {
-                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
+                    Command::Predict(args) | Command::Score(args) => {
                         predict::sync_batches(args.provider.as_ref()).await?;
                     }
+                    Command::Eval(args) => {
+                        predict::sync_batches(args.predict.provider.as_ref()).await?;
+                    }
                     _ => (),
                 }
 
                 match &command {
-                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
+                    Command::Eval(args) => {
+                        let examples = finished_examples.lock().unwrap();
+                        score::print_report(&examples);
+                        if let Some(summary_path) = &args.summary_json {
+                            score::write_summary_json(&examples, summary_path)?;
+                        }
+                    }
                     _ => (),
                 };
 

crates/edit_prediction_cli/src/score.rs 🔗

@@ -10,6 +10,10 @@ use crate::{
 use anyhow::Context as _;
 use edit_prediction::udiff::apply_diff_to_string;
 use gpui::AsyncApp;
+use serde::Serialize;
+use std::fs::File;
+use std::io::BufWriter;
+use std::path::Path;
 use std::sync::Arc;
 
 pub async fn run_scoring(
@@ -113,12 +117,12 @@ pub fn print_report(examples: &[Example]) {
     const LINE_WIDTH: usize = 100;
     let separator = "─".repeat(LINE_WIDTH);
 
-    eprintln!("{}", separator);
-    eprintln!(
+    println!("{}", separator);
+    println!(
         "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
         "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
     );
-    eprintln!("{}", separator);
+    println!("{}", separator);
 
     let mut all_delta_chr_f_scores = Vec::new();
     let mut braces_disbalance_sum: usize = 0;
@@ -133,7 +137,7 @@ pub fn print_report(examples: &[Example]) {
                 false_negatives: score.exact_lines_fn,
             };
 
-            eprintln!(
+            println!(
                 "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
                 truncate_name(&example.spec.name, 40),
                 score.delta_chr_f,
@@ -155,14 +159,14 @@ pub fn print_report(examples: &[Example]) {
         }
     }
 
-    eprintln!("{}", separator);
+    println!("{}", separator);
 
     if !all_delta_chr_f_scores.is_empty() {
         let avg_delta_chr_f: f32 =
             all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
         let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
 
-        eprintln!(
+        println!(
             "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
             "TOTAL / AVERAGE",
             avg_delta_chr_f,
@@ -174,10 +178,10 @@ pub fn print_report(examples: &[Example]) {
             total_exact_lines.recall() * 100.0,
             total_exact_lines.f1() * 100.0
         );
-        eprintln!("{}", separator);
+        println!("{}", separator);
     }
 
-    eprintln!("\n");
+    println!("\n");
 }
 
 fn truncate_name(name: &str, max_len: usize) -> String {
@@ -187,3 +191,71 @@ fn truncate_name(name: &str, max_len: usize) -> String {
         format!("{}...", &name[..max_len - 3])
     }
 }
+
+#[derive(Serialize)]
+pub struct SummaryJson {
+    pub total_examples: usize,
+    pub avg_delta_chr_f: f32,
+    pub avg_braces_disbalance: f32,
+    pub exact_lines_true_positives: usize,
+    pub exact_lines_false_positives: usize,
+    pub exact_lines_false_negatives: usize,
+    pub exact_lines_precision: f64,
+    pub exact_lines_recall: f64,
+    pub exact_lines_f1: f64,
+}
+
+pub fn compute_summary(examples: &[Example]) -> SummaryJson {
+    use crate::metrics::ClassificationMetrics;
+
+    let mut all_delta_chr_f_scores = Vec::new();
+    let mut braces_disbalance_sum: usize = 0;
+    let mut total_exact_lines = ClassificationMetrics::default();
+    let mut total_scores: usize = 0;
+
+    for example in examples {
+        for score in example.score.iter() {
+            all_delta_chr_f_scores.push(score.delta_chr_f);
+            total_scores += 1;
+            braces_disbalance_sum += score.braces_disbalance;
+            total_exact_lines.true_positives += score.exact_lines_tp;
+            total_exact_lines.false_positives += score.exact_lines_fp;
+            total_exact_lines.false_negatives += score.exact_lines_fn;
+        }
+    }
+
+    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
+        0.0
+    } else {
+        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
+    };
+
+    let avg_braces_disbalance = if total_scores == 0 {
+        0.0
+    } else {
+        braces_disbalance_sum as f32 / total_scores as f32
+    };
+
+    SummaryJson {
+        total_examples: total_scores,
+        avg_delta_chr_f,
+        avg_braces_disbalance,
+        exact_lines_true_positives: total_exact_lines.true_positives,
+        exact_lines_false_positives: total_exact_lines.false_positives,
+        exact_lines_false_negatives: total_exact_lines.false_negatives,
+        exact_lines_precision: total_exact_lines.precision(),
+        exact_lines_recall: total_exact_lines.recall(),
+        exact_lines_f1: total_exact_lines.f1(),
+    }
+}
+
+pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
+    let summary = compute_summary(examples);
+    let file = File::create(path)
+        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
+    let writer = BufWriter::new(file);
+    serde_json::to_writer_pretty(writer, &summary)
+        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
+    eprintln!("Wrote summary JSON to: {}", path.display());
+    Ok(())
+}