zeta eval: Improve output (#42629)

Created by Agus Zubiaga

Hides the aggregated scores if only one example/repetition ran. It also
fixes an issue with the expected context scoring.

Release Notes:

- N/A

Change summary

crates/zeta_cli/src/evaluate.rs | 56 +++++++++++++++++++++-------------
1 file changed, 35 insertions(+), 21 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs 🔗

@@ -94,11 +94,15 @@ fn write_aggregated_scores(
 ) -> Result<()> {
     let mut successful = Vec::new();
     let mut failed_count = 0;
-    writeln!(w, "## Errors\n")?;
+
     for result in all_results.iter().flatten() {
         match result {
             Ok(eval_result) => successful.push(eval_result),
             Err((err, name, repetition_ix)) => {
+                if failed_count == 0 {
+                    writeln!(w, "## Errors\n")?;
+                }
+
                 failed_count += 1;
                 let err = err
                     .to_string()
@@ -114,22 +118,28 @@ fn write_aggregated_scores(
             }
         }
     }
-    let aggregated_result = EvaluationResult {
-        context: Scores::aggregate(successful.iter().map(|r| &r.context)),
-        edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
-    };
 
-    writeln!(w, "\n{}", "-".repeat(80))?;
-    writeln!(w, "\n## TOTAL SCORES")?;
-    writeln!(w, "\n### Success Rate")?;
-    writeln!(
-        w,
-        "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
-        successful.len(),
-        successful.len() + failed_count,
-        (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
-    )?;
-    writeln!(w, "{}", aggregated_result)?;
+    if successful.len() > 1 {
+        let aggregated_result = EvaluationResult {
+            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
+            edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
+        };
+
+        writeln!(w, "\n{}", "-".repeat(80))?;
+        writeln!(w, "\n## TOTAL SCORES")?;
+        writeln!(w, "\n### Success Rate")?;
+        writeln!(w, "{}", aggregated_result)?;
+    }
+
+    if successful.len() + failed_count > 1 {
+        writeln!(
+            w,
+            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
+            successful.len(),
+            successful.len() + failed_count,
+            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
+        )?;
+    }
 
     Ok(())
 }
@@ -326,7 +336,7 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
     let mut false_positive_lines = actual_context_lines.clone();
 
     for entry in &example.expected_context {
-        let mut best_alternative_score = Scores::default();
+        let mut best_alternative_score: Option<Scores> = None;
 
         for alternative in &entry.alternatives {
             let expected: HashSet<_> = alternative
@@ -344,13 +354,17 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
 
             false_positive_lines.retain(|line| !actual_context_lines.contains(line));
 
-            if scores.recall() > best_alternative_score.recall() {
-                best_alternative_score = scores;
+            if best_alternative_score
+                .as_ref()
+                .is_none_or(|best| scores.recall() > best.recall())
+            {
+                best_alternative_score = Some(scores);
             }
         }
 
-        eval_result.context.false_negatives += best_alternative_score.false_negatives;
-        eval_result.context.true_positives += best_alternative_score.true_positives;
+        let best_alternative = best_alternative_score.unwrap_or_default();
+        eval_result.context.false_negatives += best_alternative.false_negatives;
+        eval_result.context.true_positives += best_alternative.true_positives;
     }
 
     eval_result.context.false_positives = false_positive_lines.len();