diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index a68b68775014ff9a78a68c989fd9b59d176b4dad..8329e394bfa8fe47b9a40f327e7cb38b9839da6c 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -146,7 +146,7 @@ enum Command {
     /// predicted outputs and removing actual outputs and prompts.
     Distill,
     /// Print aggregated scores
-    Eval(PredictArgs),
+    Eval(EvalArgs),
     /// Generate eval examples by analyzing git commits from a repository
     Synthesize(SynthesizeArgs),
     /// Remove git repositories and worktrees
@@ -180,7 +180,7 @@ impl Display for Command {
                 None => write!(f, "score"),
             },
             Command::Distill => write!(f, "distill"),
-            Command::Eval(args) => match &args.provider {
+            Command::Eval(args) => match &args.predict.provider {
                 Some(provider) => write!(f, "eval --provider={}", provider),
                 None => write!(f, "eval"),
             },
@@ -212,6 +212,15 @@ struct PredictArgs {
     repetitions: usize,
 }
 
+#[derive(Debug, Args, Clone)]
+struct EvalArgs {
+    #[clap(flatten)]
+    predict: PredictArgs,
+    /// Path to write summary scores as JSON
+    #[clap(long)]
+    summary_json: Option<PathBuf>,
+}
+
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 enum PredictionProvider {
     Sweep,
@@ -570,9 +579,12 @@ fn main() {
             .await?;
 
         match &command {
-            Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
+            Command::Predict(args) | Command::Score(args) => {
                 predict::sync_batches(args.provider.as_ref()).await?;
             }
+            Command::Eval(args) => {
+                predict::sync_batches(args.predict.provider.as_ref()).await?;
+            }
             _ => (),
         }
 
@@ -687,10 +699,20 @@ fn main() {
                     Command::Distill => {
                         run_distill(example).await?;
                     }
-                    Command::Score(args) | Command::Eval(args) => {
+                    Command::Score(args) => {
                         run_scoring(
                             example,
-                            &args,
+                            args,
+                            app_state.clone(),
+                            &example_progress,
+                            cx.clone(),
+                        )
+                        .await?;
+                    }
+                    Command::Eval(args) => {
+                        run_scoring(
+                            example,
+                            &args.predict,
                             app_state.clone(),
                             &example_progress,
                             cx.clone(),
@@ -786,14 +808,23 @@ fn main() {
         Progress::global().finalize();
 
         match &command {
-            Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
+            Command::Predict(args) | Command::Score(args) => {
                 predict::sync_batches(args.provider.as_ref()).await?;
            }
+            Command::Eval(args) => {
+                predict::sync_batches(args.predict.provider.as_ref()).await?;
+            }
             _ => (),
         }
 
         match &command {
-            Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
+            Command::Eval(args) => {
+                let examples = finished_examples.lock().unwrap();
+                score::print_report(&examples);
+                if let Some(summary_path) = &args.summary_json {
+                    score::write_summary_json(&examples, summary_path)?;
+                }
+            }
             _ => (),
         };
 
diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs
index 4c5a8d5d0a97c07691fcfab2b5b7b13631f9a5b8..1b403f50a2590f3e5a5fd2c52bf3c31897745621 100644
--- a/crates/edit_prediction_cli/src/score.rs
+++ b/crates/edit_prediction_cli/src/score.rs
@@ -10,6 +10,10 @@ use crate::{
 use anyhow::Context as _;
 use edit_prediction::udiff::apply_diff_to_string;
 use gpui::AsyncApp;
+use serde::Serialize;
+use std::fs::File;
+use std::io::BufWriter;
+use std::path::Path;
 use std::sync::Arc;
 
 pub async fn run_scoring(
@@ -113,12 +117,12 @@ pub fn print_report(examples: &[Example]) {
     const LINE_WIDTH: usize = 100;
     let separator = "─".repeat(LINE_WIDTH);
 
-    eprintln!("{}", separator);
-    eprintln!(
+    println!("{}", separator);
+    println!(
         "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
         "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
     );
-    eprintln!("{}", separator);
+    println!("{}", separator);
 
     let mut all_delta_chr_f_scores = Vec::new();
     let mut braces_disbalance_sum: usize = 0;
@@ -133,7 +137,7 @@ pub fn print_report(examples: &[Example]) {
             false_negatives: score.exact_lines_fn,
         };
 
-        eprintln!(
+        println!(
             "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
             truncate_name(&example.spec.name, 40),
             score.delta_chr_f,
@@ -155,14 +159,14 @@ pub fn print_report(examples: &[Example]) {
         }
     }
 
-    eprintln!("{}", separator);
+    println!("{}", separator);
 
     if !all_delta_chr_f_scores.is_empty() {
         let avg_delta_chr_f: f32 =
             all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
         let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
 
-        eprintln!(
+        println!(
             "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
             "TOTAL / AVERAGE",
             avg_delta_chr_f,
@@ -174,10 +178,10 @@ pub fn print_report(examples: &[Example]) {
             total_exact_lines.recall() * 100.0,
             total_exact_lines.f1() * 100.0
         );
-        eprintln!("{}", separator);
+        println!("{}", separator);
     }
 
-    eprintln!("\n");
+    println!("\n");
 }
 
 fn truncate_name(name: &str, max_len: usize) -> String {
@@ -187,3 +191,71 @@ fn truncate_name(name: &str, max_len: usize) -> String {
         format!("{}...", &name[..max_len - 3])
     }
 }
+
+#[derive(Serialize)]
+pub struct SummaryJson {
+    pub total_examples: usize,
+    pub avg_delta_chr_f: f32,
+    pub avg_braces_disbalance: f32,
+    pub exact_lines_true_positives: usize,
+    pub exact_lines_false_positives: usize,
+    pub exact_lines_false_negatives: usize,
+    pub exact_lines_precision: f64,
+    pub exact_lines_recall: f64,
+    pub exact_lines_f1: f64,
+}
+
+pub fn compute_summary(examples: &[Example]) -> SummaryJson {
+    use crate::metrics::ClassificationMetrics;
+
+    let mut all_delta_chr_f_scores = Vec::new();
+    let mut braces_disbalance_sum: usize = 0;
+    let mut total_exact_lines = ClassificationMetrics::default();
+    let mut total_scores: usize = 0;
+
+    for example in examples {
+        for score in example.score.iter() {
+            all_delta_chr_f_scores.push(score.delta_chr_f);
+            total_scores += 1;
+            braces_disbalance_sum += score.braces_disbalance;
+            total_exact_lines.true_positives += score.exact_lines_tp;
+            total_exact_lines.false_positives += score.exact_lines_fp;
+            total_exact_lines.false_negatives += score.exact_lines_fn;
+        }
+    }
+
+    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
+        0.0
+    } else {
+        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
+    };
+
+    let avg_braces_disbalance = if total_scores == 0 {
+        0.0
+    } else {
+        braces_disbalance_sum as f32 / total_scores as f32
+    };
+
+    SummaryJson {
+        total_examples: total_scores,
+        avg_delta_chr_f,
+        avg_braces_disbalance,
+        exact_lines_true_positives: total_exact_lines.true_positives,
+        exact_lines_false_positives: total_exact_lines.false_positives,
+        exact_lines_false_negatives: total_exact_lines.false_negatives,
+        exact_lines_precision: total_exact_lines.precision(),
+        exact_lines_recall: total_exact_lines.recall(),
+        exact_lines_f1: total_exact_lines.f1(),
+    }
+}
+
+pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
+    let summary = compute_summary(examples);
+    let file = File::create(path)
+        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
+    let writer = BufWriter::new(file);
+    serde_json::to_writer_pretty(writer, &summary)
+        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
+ eprintln!("Wrote summary JSON to: {}", path.display()); + Ok(()) +}