diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs
index 6644ecbb5a72b4d7218a9d33bcc1f9f602c3f65d..4f8e984a7de36a96c4e8ad3ac7e5d9e9bfda244b 100644
--- a/crates/zeta_cli/src/evaluate.rs
+++ b/crates/zeta_cli/src/evaluate.rs
@@ -94,11 +94,15 @@ fn write_aggregated_scores(
 ) -> Result<()> {
     let mut successful = Vec::new();
     let mut failed_count = 0;
-    writeln!(w, "## Errors\n")?;
+
     for result in all_results.iter().flatten() {
         match result {
             Ok(eval_result) => successful.push(eval_result),
             Err((err, name, repetition_ix)) => {
+                if failed_count == 0 {
+                    writeln!(w, "## Errors\n")?;
+                }
+
                 failed_count += 1;
                 let err = err
                     .to_string()
@@ -114,22 +118,28 @@ fn write_aggregated_scores(
             }
         }
     }
-    let aggregated_result = EvaluationResult {
-        context: Scores::aggregate(successful.iter().map(|r| &r.context)),
-        edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
-    };
-    writeln!(w, "\n{}", "-".repeat(80))?;
-    writeln!(w, "\n## TOTAL SCORES")?;
-    writeln!(w, "\n### Success Rate")?;
-    writeln!(
-        w,
-        "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
-        successful.len(),
-        successful.len() + failed_count,
-        (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
-    )?;
-    writeln!(w, "{}", aggregated_result)?;
+    if successful.len() > 1 {
+        let aggregated_result = EvaluationResult {
+            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
+            edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
+        };
+
+        writeln!(w, "\n{}", "-".repeat(80))?;
+        writeln!(w, "\n## TOTAL SCORES")?;
+        writeln!(w, "\n### Success Rate")?;
+        writeln!(w, "{}", aggregated_result)?;
+    }
+
+    if successful.len() + failed_count > 1 {
+        writeln!(
+            w,
+            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
+            successful.len(),
+            successful.len() + failed_count,
+            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
+        )?;
+    }
 
     Ok(())
 }
@@ -326,7 +336,7 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
     let mut false_positive_lines = actual_context_lines.clone();
 
     for entry in &example.expected_context {
-        let mut best_alternative_score = Scores::default();
+        let mut best_alternative_score: Option<Scores> = None;
 
         for alternative in &entry.alternatives {
             let expected: HashSet<_> = alternative
@@ -344,13 +354,17 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
 
             false_positive_lines.retain(|line| !actual_context_lines.contains(line));
 
-            if scores.recall() > best_alternative_score.recall() {
-                best_alternative_score = scores;
+            if best_alternative_score
+                .as_ref()
+                .is_none_or(|best| scores.recall() > best.recall())
+            {
+                best_alternative_score = Some(scores);
             }
         }
 
-        eval_result.context.false_negatives += best_alternative_score.false_negatives;
-        eval_result.context.true_positives += best_alternative_score.true_positives;
+        let best_alternative = best_alternative_score.unwrap_or_default();
+        eval_result.context.false_negatives += best_alternative.false_negatives;
+        eval_result.context.true_positives += best_alternative.true_positives;
     }
 
     eval_result.context.false_positives = false_positive_lines.len();