From f1bebd79d1e1debf21efc206ddf9c0606ef6cca4 Mon Sep 17 00:00:00 2001
From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:37:51 +0100
Subject: [PATCH] zeta2: Add skip-prediction flag to eval CLI (#42872)

Release Notes:

- N/A
---
 crates/zeta_cli/src/evaluate.rs | 135 +++++++++++++++++++-------
 1 file changed, 81 insertions(+), 54 deletions(-)

diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs
index b0b3820362889051e3e5c0eef03ef10c7f0d6fa8..a06662c8bf17535900923eb875261f911ded12f7 100644
--- a/crates/zeta_cli/src/evaluate.rs
+++ b/crates/zeta_cli/src/evaluate.rs
@@ -31,6 +31,8 @@ pub struct EvaluateArguments {
     cache: CacheMode,
     #[clap(short, long, default_value_t = 1, alias = "repeat")]
     repetitions: u16,
+    #[arg(long)]
+    skip_prediction: bool,
 }
 
 pub async fn run_evaluate(
@@ -66,6 +68,7 @@ pub async fn run_evaluate(
             zeta,
             args.prompt_format,
             args.use_expected_context,
+            !args.skip_prediction,
             args.cache,
             cx,
         )
@@ -118,9 +121,14 @@ fn write_aggregated_scores(
     }
 
     if successful.len() > 1 {
+        let mut edit_predictions = successful
+            .iter()
+            .filter_map(|r| r.edit_prediction.as_ref())
+            .peekable();
+        let has_edit_predictions = edit_predictions.peek().is_some();
         let aggregated_result = EvaluationResult {
             context: Scores::aggregate(successful.iter().map(|r| &r.context)),
-            edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
+            edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
         };
 
         writeln!(w, "\n{}", "-".repeat(80))?;
@@ -149,6 +157,7 @@ pub async fn run_evaluate_one(
     zeta: Entity,
     prompt_format: PromptFormat,
     use_expected_context: bool,
+    predict: bool,
     cache_mode: CacheMode,
     cx: &mut AsyncApp,
 ) -> Result {
@@ -164,7 +173,7 @@ pub async fn run_evaluate_one(
     )
     .await?;
 
-    let evaluation_result = evaluate(&example.example, &predict_result);
+    let evaluation_result = evaluate(&example.example, &predict_result, predict);
 
     if repetition_ix.is_none() {
         write_eval_result(
@@ -173,6 +182,7 @@ pub async fn run_evaluate_one(
             &evaluation_result,
             &mut std::io::stdout(),
             std::io::stdout().is_terminal(),
+            predict,
         )?;
     }
 
@@ -185,6 +195,7 @@ pub async fn run_evaluate_one(
             &evaluation_result,
             &mut results_file,
             false,
+            predict,
         )
         .log_err();
     }
@@ -198,25 +209,29 @@ fn write_eval_result(
     evaluation_result: &EvaluationResult,
     out: &mut impl Write,
     use_color: bool,
+    predict: bool,
 ) -> Result<()> {
-    writeln!(
-        out,
-        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
-        compare_diffs(
-            &example.example.expected_patch,
-            &predictions.diff,
-            use_color
-        )
-    )?;
-    writeln!(
-        out,
-        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
-        compare_diffs(
-            &predictions.diff,
-            &example.example.expected_patch,
-            use_color
-        )
-    )?;
+    if predict {
+        writeln!(
+            out,
+            "## Expected edit prediction:\n\n```diff\n{}\n```\n",
+            compare_diffs(
+                &example.example.expected_patch,
+                &predictions.diff,
+                use_color
+            )
+        )?;
+        writeln!(
+            out,
+            "## Actual edit prediction:\n\n```diff\n{}\n```\n",
+            compare_diffs(
+                &predictions.diff,
+                &example.example.expected_patch,
+                use_color
+            )
+        )?;
+    }
+
     writeln!(out, "{:#}", evaluation_result)?;
 
     anyhow::Ok(())
@@ -224,7 +239,7 @@ fn write_eval_result(
 }
 
 #[derive(Debug, Default)]
 pub struct EvaluationResult {
-    pub edit_prediction: Scores,
+    pub edit_prediction: Option<Scores>,
     pub context: Scores,
 }
@@ -328,13 +343,19 @@ impl EvaluationResult {
             r#"
 ### Context Scores
 {}
-
-### Edit Prediction Scores
-{}
 "#,
             self.context.to_markdown(),
-            self.edit_prediction.to_markdown()
-        )
+        )?;
+        if let Some(prediction) = &self.edit_prediction {
+            write!(
+                f,
+                r#"
+### Edit Prediction Scores
+{}"#,
+                prediction.to_markdown()
+            )?;
+        }
+        Ok(())
     }
 
     fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -357,20 +378,23 @@ impl EvaluationResult {
             self.context.recall() * 100.0,
             self.context.f1_score() * 100.0
         )?;
-        writeln!(
-            f,
-            "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
-            self.edit_prediction.true_positives,
-            self.edit_prediction.false_positives,
-            self.edit_prediction.false_negatives,
-            self.edit_prediction.precision() * 100.0,
-            self.edit_prediction.recall() * 100.0,
-            self.edit_prediction.f1_score() * 100.0
-        )
+        if let Some(edit_prediction) = &self.edit_prediction {
+            writeln!(
+                f,
+                "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+                edit_prediction.true_positives,
+                edit_prediction.false_positives,
+                edit_prediction.false_negatives,
+                edit_prediction.precision() * 100.0,
+                edit_prediction.recall() * 100.0,
+                edit_prediction.f1_score() * 100.0
+            )?;
+        }
+        Ok(())
     }
 }
 
-pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
+pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
     let mut eval_result = EvaluationResult::default();
 
     let actual_context_lines: HashSet<_> = preds
@@ -420,24 +444,27 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
     eval_result.context.false_positives = false_positive_lines.len();
 
-    // todo: alternatives for patches
-    let expected_patch_lines = example
-        .expected_patch
-        .lines()
-        .map(DiffLine::parse)
-        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
-        .map(|line| line.to_string())
-        .collect();
-
-    let actual_patch_lines = preds
-        .diff
-        .lines()
-        .map(DiffLine::parse)
-        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
-        .map(|line| line.to_string())
-        .collect();
+    if predict {
+        // todo: alternatives for patches
+        let expected_patch_lines = example
+            .expected_patch
+            .lines()
+            .map(DiffLine::parse)
+            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+            .map(|line| line.to_string())
+            .collect();
+
+        let actual_patch_lines = preds
+            .diff
+            .lines()
+            .map(DiffLine::parse)
+            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+            .map(|line| line.to_string())
+            .collect();
+
+        eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
+    }
 
-    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
     eval_result
 }
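
Usage sketch. Only the `--skip-prediction` flag itself comes from this patch (clap derives the
kebab-case name from the `skip_prediction` field); the `evaluate` subcommand name and the
example-path argument below are assumptions about the existing zeta_cli interface, not something
this diff shows:

    # Assumed invocation: score context retrieval only, skipping edit-prediction scoring
    cargo run -p zeta_cli -- evaluate --skip-prediction path/to/example

    # Without the flag, edit-prediction scoring still runs (the default)
    cargo run -p zeta_cli -- evaluate path/to/example

With the flag set, run_evaluate passes `!args.skip_prediction` (i.e. `predict = false`) down to
`evaluate()`, so `EvaluationResult::edit_prediction` stays `None` and both the per-example output
and the aggregated table report context scores only.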