zeta2: Add skip-prediction flag to eval CLI (#42872)

Piotr Osiewicz created

Release Notes:

- N/A

Change summary

crates/zeta_cli/src/evaluate.rs | 135 +++++++++++++++++++++--------------
1 file changed, 81 insertions(+), 54 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs 🔗

@@ -31,6 +31,8 @@ pub struct EvaluateArguments {
     cache: CacheMode,
     #[clap(short, long, default_value_t = 1, alias = "repeat")]
     repetitions: u16,
+    #[arg(long)]
+    skip_prediction: bool,
 }
 
 pub async fn run_evaluate(
@@ -66,6 +68,7 @@ pub async fn run_evaluate(
                         zeta,
                         args.prompt_format,
                         args.use_expected_context,
+                        !args.skip_prediction,
                         args.cache,
                         cx,
                     )
@@ -118,9 +121,14 @@ fn write_aggregated_scores(
     }
 
     if successful.len() > 1 {
+        let mut edit_predictions = successful
+            .iter()
+            .filter_map(|r| r.edit_prediction.as_ref())
+            .peekable();
+        let has_edit_predictions = edit_predictions.peek().is_some();
         let aggregated_result = EvaluationResult {
             context: Scores::aggregate(successful.iter().map(|r| &r.context)),
-            edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
+            edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
         };
 
         writeln!(w, "\n{}", "-".repeat(80))?;
@@ -149,6 +157,7 @@ pub async fn run_evaluate_one(
     zeta: Entity<Zeta>,
     prompt_format: PromptFormat,
     use_expected_context: bool,
+    predict: bool,
     cache_mode: CacheMode,
     cx: &mut AsyncApp,
 ) -> Result<EvaluationResult> {
@@ -164,7 +173,7 @@ pub async fn run_evaluate_one(
     )
     .await?;
 
-    let evaluation_result = evaluate(&example.example, &predict_result);
+    let evaluation_result = evaluate(&example.example, &predict_result, predict);
 
     if repetition_ix.is_none() {
         write_eval_result(
@@ -173,6 +182,7 @@ pub async fn run_evaluate_one(
             &evaluation_result,
             &mut std::io::stdout(),
             std::io::stdout().is_terminal(),
+            predict,
         )?;
     }
 
@@ -185,6 +195,7 @@ pub async fn run_evaluate_one(
             &evaluation_result,
             &mut results_file,
             false,
+            predict,
         )
         .log_err();
     }
@@ -198,25 +209,29 @@ fn write_eval_result(
     evaluation_result: &EvaluationResult,
     out: &mut impl Write,
     use_color: bool,
+    predict: bool,
 ) -> Result<()> {
-    writeln!(
-        out,
-        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
-        compare_diffs(
-            &example.example.expected_patch,
-            &predictions.diff,
-            use_color
-        )
-    )?;
-    writeln!(
-        out,
-        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
-        compare_diffs(
-            &predictions.diff,
-            &example.example.expected_patch,
-            use_color
-        )
-    )?;
+    if predict {
+        writeln!(
+            out,
+            "## Expected edit prediction:\n\n```diff\n{}\n```\n",
+            compare_diffs(
+                &example.example.expected_patch,
+                &predictions.diff,
+                use_color
+            )
+        )?;
+        writeln!(
+            out,
+            "## Actual edit prediction:\n\n```diff\n{}\n```\n",
+            compare_diffs(
+                &predictions.diff,
+                &example.example.expected_patch,
+                use_color
+            )
+        )?;
+    }
+
     writeln!(out, "{:#}", evaluation_result)?;
 
     anyhow::Ok(())
@@ -224,7 +239,7 @@ fn write_eval_result(
 
 #[derive(Debug, Default)]
 pub struct EvaluationResult {
-    pub edit_prediction: Scores,
+    pub edit_prediction: Option<Scores>,
     pub context: Scores,
 }
 
@@ -328,13 +343,19 @@ impl EvaluationResult {
             r#"
 ### Context Scores
 {}
-
-### Edit Prediction Scores
-{}
 "#,
             self.context.to_markdown(),
-            self.edit_prediction.to_markdown()
-        )
+        )?;
+        if let Some(prediction) = &self.edit_prediction {
+            write!(
+                f,
+                r#"
+                ### Edit Prediction Scores
+                {}"#,
+                prediction.to_markdown()
+            )?;
+        }
+        Ok(())
     }
 
     fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -357,20 +378,23 @@ impl EvaluationResult {
             self.context.recall() * 100.0,
             self.context.f1_score() * 100.0
         )?;
-        writeln!(
-            f,
-            "Edit Prediction    {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
-            self.edit_prediction.true_positives,
-            self.edit_prediction.false_positives,
-            self.edit_prediction.false_negatives,
-            self.edit_prediction.precision() * 100.0,
-            self.edit_prediction.recall() * 100.0,
-            self.edit_prediction.f1_score() * 100.0
-        )
+        if let Some(edit_prediction) = &self.edit_prediction {
+            writeln!(
+                f,
+                "Edit Prediction    {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+                edit_prediction.true_positives,
+                edit_prediction.false_positives,
+                edit_prediction.false_negatives,
+                edit_prediction.precision() * 100.0,
+                edit_prediction.recall() * 100.0,
+                edit_prediction.f1_score() * 100.0
+            )?;
+        }
+        Ok(())
     }
 }
 
-pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
+pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
     let mut eval_result = EvaluationResult::default();
 
     let actual_context_lines: HashSet<_> = preds
@@ -420,24 +444,27 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
 
     eval_result.context.false_positives = false_positive_lines.len();
 
-    // todo: alternatives for patches
-    let expected_patch_lines = example
-        .expected_patch
-        .lines()
-        .map(DiffLine::parse)
-        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
-        .map(|line| line.to_string())
-        .collect();
-
-    let actual_patch_lines = preds
-        .diff
-        .lines()
-        .map(DiffLine::parse)
-        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
-        .map(|line| line.to_string())
-        .collect();
+    if predict {
+        // todo: alternatives for patches
+        let expected_patch_lines = example
+            .expected_patch
+            .lines()
+            .map(DiffLine::parse)
+            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+            .map(|line| line.to_string())
+            .collect();
+
+        let actual_patch_lines = preds
+            .diff
+            .lines()
+            .map(DiffLine::parse)
+            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+            .map(|line| line.to_string())
+            .collect();
+
+        eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
+    }
 
-    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
     eval_result
 }