@@ -146,7 +146,7 @@ enum Command {
/// predicted outputs and removing actual outputs and prompts.
Distill,
/// Print aggregated scores
- Eval(PredictArgs),
+ Eval(EvalArgs),
/// Generate eval examples by analyzing git commits from a repository
Synthesize(SynthesizeArgs),
/// Remove git repositories and worktrees
@@ -180,7 +180,7 @@ impl Display for Command {
None => write!(f, "score"),
},
Command::Distill => write!(f, "distill"),
- Command::Eval(args) => match &args.provider {
+ Command::Eval(args) => match &args.predict.provider {
Some(provider) => write!(f, "eval --provider={}", provider),
None => write!(f, "eval"),
},
@@ -212,6 +212,15 @@ struct PredictArgs {
repetitions: usize,
}
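+// `eval` takes the same options as `predict`, plus eval-specific output settings.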
+#[derive(Debug, Args, Clone)]
+struct EvalArgs {
+    #[clap(flatten)]
+    predict: PredictArgs,
+    /// Path to write summary scores as JSON
+    #[clap(long)]
+    summary_json: Option<PathBuf>,
+}
+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
Sweep,
@@ -570,9 +579,12 @@ fn main() {
.await?;
match &command {
- Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
+ Command::Predict(args) | Command::Score(args) => {
predict::sync_batches(args.provider.as_ref()).await?;
}
+ Command::Eval(args) => {
+ predict::sync_batches(args.predict.provider.as_ref()).await?;
+ }
_ => (),
}
@@ -687,10 +699,20 @@ fn main() {
Command::Distill => {
run_distill(example).await?;
}
- Command::Score(args) | Command::Eval(args) => {
+ Command::Score(args) => {
run_scoring(
example,
- &args,
+ args,
+ app_state.clone(),
+ &example_progress,
+ cx.clone(),
+ )
+ .await?;
+ }
+ Command::Eval(args) => {
+ run_scoring(
+ example,
+ &args.predict,
app_state.clone(),
&example_progress,
cx.clone(),
@@ -786,14 +808,23 @@ fn main() {
Progress::global().finalize();
match &command {
- Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
+ Command::Predict(args) | Command::Score(args) => {
predict::sync_batches(args.provider.as_ref()).await?;
}
+ Command::Eval(args) => {
+ predict::sync_batches(args.predict.provider.as_ref()).await?;
+ }
_ => (),
}
match &command {
- Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
+ Command::Eval(args) => {
+ let examples = finished_examples.lock().unwrap();
+ score::print_report(&examples);
+ if let Some(summary_path) = &args.summary_json {
+ score::write_summary_json(&examples, summary_path)?;
+ }
+ }
_ => (),
};
@@ -10,6 +10,10 @@ use crate::{
use anyhow::Context as _;
use edit_prediction::udiff::apply_diff_to_string;
use gpui::AsyncApp;
+use serde::Serialize;
+use std::fs::File;
+use std::io::BufWriter;
+use std::path::Path;
use std::sync::Arc;
pub async fn run_scoring(
@@ -113,12 +117,12 @@ pub fn print_report(examples: &[Example]) {
const LINE_WIDTH: usize = 100;
let separator = "─".repeat(LINE_WIDTH);
- eprintln!("{}", separator);
- eprintln!(
+ println!("{}", separator);
+ println!(
"{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
"Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
);
- eprintln!("{}", separator);
+ println!("{}", separator);
let mut all_delta_chr_f_scores = Vec::new();
let mut braces_disbalance_sum: usize = 0;
@@ -133,7 +137,7 @@ pub fn print_report(examples: &[Example]) {
false_negatives: score.exact_lines_fn,
};
- eprintln!(
+ println!(
"{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
truncate_name(&example.spec.name, 40),
score.delta_chr_f,
@@ -155,14 +159,14 @@ pub fn print_report(examples: &[Example]) {
}
}
- eprintln!("{}", separator);
+ println!("{}", separator);
if !all_delta_chr_f_scores.is_empty() {
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
- eprintln!(
+ println!(
"{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
"TOTAL / AVERAGE",
avg_delta_chr_f,
@@ -174,10 +178,10 @@ pub fn print_report(examples: &[Example]) {
total_exact_lines.recall() * 100.0,
total_exact_lines.f1() * 100.0
);
- eprintln!("{}", separator);
+ println!("{}", separator);
}
- eprintln!("\n");
+ println!("\n");
}
fn truncate_name(name: &str, max_len: usize) -> String {
@@ -187,3 +191,71 @@ fn truncate_name(name: &str, max_len: usize) -> String {
format!("{}...", &name[..max_len - 3])
}
}
+
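+/// Aggregate metrics over every collected score, serialized to JSON by `write_summary_json`.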
+#[derive(Serialize)]
+pub struct SummaryJson {
+    pub total_scores: usize,
+    pub avg_delta_chr_f: f32,
+    pub avg_braces_disbalance: f32,
+    pub exact_lines_true_positives: usize,
+    pub exact_lines_false_positives: usize,
+    pub exact_lines_false_negatives: usize,
+    pub exact_lines_precision: f64,
+    pub exact_lines_recall: f64,
+    pub exact_lines_f1: f64,
+}
+
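+/// Folds the per-example scores into a single `SummaryJson`, mirroring the totals and
+/// averages that `print_report` prints.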
+pub fn compute_summary(examples: &[Example]) -> SummaryJson {
+    use crate::metrics::ClassificationMetrics;
+
+    let mut all_delta_chr_f_scores = Vec::new();
+    let mut braces_disbalance_sum: usize = 0;
+    let mut total_exact_lines = ClassificationMetrics::default();
+    let mut total_scores: usize = 0;
+
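+    // An example may carry several scores (e.g. one per repetition); accumulate them all.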
+    for example in examples {
+        for score in example.score.iter() {
+            all_delta_chr_f_scores.push(score.delta_chr_f);
+            total_scores += 1;
+            braces_disbalance_sum += score.braces_disbalance;
+            total_exact_lines.true_positives += score.exact_lines_tp;
+            total_exact_lines.false_positives += score.exact_lines_fp;
+            total_exact_lines.false_negatives += score.exact_lines_fn;
+        }
+    }
+
+    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
+        0.0
+    } else {
+        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
+    };
+
+    let avg_braces_disbalance = if total_scores == 0 {
+        0.0
+    } else {
+        braces_disbalance_sum as f32 / total_scores as f32
+    };
+
+    SummaryJson {
+        total_scores,
+        avg_delta_chr_f,
+        avg_braces_disbalance,
+        exact_lines_true_positives: total_exact_lines.true_positives,
+        exact_lines_false_positives: total_exact_lines.false_positives,
+        exact_lines_false_negatives: total_exact_lines.false_negatives,
+        exact_lines_precision: total_exact_lines.precision(),
+        exact_lines_recall: total_exact_lines.recall(),
+        exact_lines_f1: total_exact_lines.f1(),
+    }
+}
+
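+/// Computes the summary for `examples` and writes it to `path` as pretty-printed JSON.
+/// The confirmation message goes to stderr so stdout stays reserved for the report.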
+pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
+    let summary = compute_summary(examples);
+    let file = File::create(path)
+        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
+    let writer = BufWriter::new(file);
+    serde_json::to_writer_pretty(writer, &summary)
+        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
+    eprintln!("Wrote summary JSON to: {}", path.display());
+    Ok(())
+}