zeta2: Output `bucketed_analysis.md` (#42890)

Created by Ben Kunkle

Closes #ISSUE

Makes it so that a file named `bucketed_analysis.md` is written to the
runs directory after an eval is run with more than one repetition. The
file buckets the model's predictions by comparing the edits they
produce, which makes it much easier to see how often each failure mode
was encountered.
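
For reference, a run with 8 repetitions might produce a file shaped roughly like the sketch below (assembled from the `write_bucketed_analysis` calls in the diff; the counts, execution ids, and section contents are illustrative):

```markdown
# Bucketed Edit Analysis

## Summary

- **Total executions**: 8
- **Correct predictions**: 5 (62.5%)
- **Incorrect predictions**: 2 (25.0%)
- **No Predictions**: 1 (12.5%)
- **Unique incorrect edit patterns**: 1

---

## Correct Predictions (5 occurrences)

**Predicted Edit:**

(predicted edit as a fenced diff block)

**Executions:** 000, 002, 003, 005, 007

---

## Incorrect Prediction #1 (2 occurrences)

**Predicted Edit:**

(predicted edit as a fenced diff block)

**Executions:** 001, 006

  ### Execution 001 `<run dir>/001/prediction_response.md`

  (model reasoning from prediction_response.md, indented)

---

## No Predictions (1 occurrences)

  ### Execution 004 `<run dir>/004/prediction_response.md`

  (model reasoning, indented)
```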

Release Notes:

- N/A

Change summary

crates/zeta_cli/src/evaluate.rs | 283 +++++++++++++++++++++++++++++++++-
1 file changed, 268 insertions(+), 15 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs

@@ -1,4 +1,5 @@
 use std::{
+    collections::HashMap,
     io::{IsTerminal, Write},
     path::PathBuf,
     sync::Arc,
@@ -35,6 +36,13 @@ pub struct EvaluateArguments {
     skip_prediction: bool,
 }
 
+#[derive(Debug)]
+pub(crate) struct ExecutionData {
+    execution_id: String,
+    diff: String,
+    reasoning: String,
+}
+
 pub async fn run_evaluate(
     args: EvaluateArguments,
     app_state: &Arc<ZetaCliAppState>,
@@ -87,35 +95,35 @@ pub async fn run_evaluate(
     {
         write_aggregated_scores(&mut output_file, &all_results).log_err();
     };
+
+    if args.repetitions > 1 {
+        if let Err(e) = write_bucketed_analysis(&all_results) {
+            eprintln!("Failed to write bucketed analysis: {:?}", e);
+        }
+    }
+
     print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
 }
 
 fn write_aggregated_scores(
     w: &mut impl std::io::Write,
-    all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
+    all_results: &Vec<
+        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
+    >,
 ) -> Result<()> {
     let mut successful = Vec::new();
     let mut failed_count = 0;
 
     for result in all_results.iter().flatten() {
         match result {
-            Ok(eval_result) => successful.push(eval_result),
+            Ok((eval_result, _execution_data)) => successful.push(eval_result),
             Err((err, name, repetition_ix)) => {
                 if failed_count == 0 {
                     writeln!(w, "## Errors\n")?;
                 }
 
                 failed_count += 1;
-                let err = format!("{err:?}")
-                    .replace("<edits", "```xml\n<edits")
-                    .replace("</edits>", "</edits>\n```");
-                writeln!(
-                    w,
-                    "### ERROR {name}{}\n\n{err}\n",
-                    repetition_ix
-                        .map(|ix| format!(" [RUN {ix:03}]"))
-                        .unwrap_or_default()
-                )?;
+                writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
             }
         }
     }
@@ -136,7 +144,6 @@ fn write_aggregated_scores(
 
         writeln!(w, "\n{}", "-".repeat(80))?;
         writeln!(w, "\n## TOTAL SCORES")?;
-        writeln!(w, "\n### Success Rate")?;
         writeln!(w, "{:#}", aggregated_result)?;
     }
 
@@ -163,7 +170,7 @@ pub async fn run_evaluate_one(
     predict: bool,
     cache_mode: CacheMode,
     cx: &mut AsyncApp,
-) -> Result<EvaluationResult> {
+) -> Result<(EvaluationResult, ExecutionData)> {
     let predict_result = zeta2_predict(
         example.clone(),
         project,
@@ -203,7 +210,22 @@ pub async fn run_evaluate_one(
         .log_err();
     }
 
-    anyhow::Ok(evaluation_result)
+    let execution_data = ExecutionData {
+        execution_id: if let Some(rep_ix) = repetition_ix {
+            format!("{:03}", rep_ix)
+        } else {
+            example.name.clone()
+        },
+        diff: predict_result.diff.clone(),
+        reasoning: std::fs::read_to_string(
+            predict_result
+                .run_example_dir
+                .join("prediction_response.md"),
+        )
+        .unwrap_or_default(),
+    };
+
+    anyhow::Ok((evaluation_result, execution_data))
 }
 
 fn write_eval_result(
@@ -507,3 +529,234 @@ pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
 
     annotated.join("\n")
 }
+
+fn write_bucketed_analysis(
+    all_results: &Vec<
+        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
+    >,
+) -> Result<()> {
+    #[derive(Debug)]
+    struct EditBucket {
+        diff: String,
+        is_correct: bool,
+        execution_indices: Vec<String>,
+        reasoning_samples: Vec<String>,
+    }
+
+    let mut total_executions = 0;
+    let mut empty_predictions = Vec::new();
+    let mut errors = Vec::new();
+
+    let mut buckets: HashMap<String, EditBucket> = HashMap::new();
+
+    for result in all_results.iter().flatten() {
+        total_executions += 1;
+
+        let (evaluation_result, execution_data) = match result {
+            Ok((eval_result, execution_data)) => {
+                if execution_data.diff.is_empty() {
+                    empty_predictions.push(execution_data);
+                    continue;
+                }
+                (eval_result, execution_data)
+            }
+            Err(err) => {
+                errors.push(err);
+                continue;
+            }
+        };
+
+        buckets
+            .entry(execution_data.diff.clone())
+            .and_modify(|bucket| {
+                bucket
+                    .execution_indices
+                    .push(execution_data.execution_id.clone());
+                bucket
+                    .reasoning_samples
+                    .push(execution_data.reasoning.clone());
+            })
+            .or_insert_with(|| EditBucket {
+                diff: execution_data.diff.clone(),
+                is_correct: {
+                    evaluation_result
+                        .edit_prediction
+                        .as_ref()
+                        .map_or(false, |edit_prediction| {
+                            edit_prediction.false_positives == 0
+                                && edit_prediction.false_negatives == 0
+                                && edit_prediction.true_positives > 0
+                        })
+                },
+                execution_indices: vec![execution_data.execution_id.clone()],
+                reasoning_samples: vec![execution_data.reasoning.clone()],
+            });
+    }
+
+    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
+    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
+        (true, false) => std::cmp::Ordering::Less,
+        (false, true) => std::cmp::Ordering::Greater,
+        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
+    });
+
+    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
+    let mut output = std::fs::File::create(&output_path)?;
+
+    writeln!(output, "# Bucketed Edit Analysis\n")?;
+
+    writeln!(output, "## Summary\n")?;
+    writeln!(output, "- **Total executions**: {}", total_executions)?;
+
+    let correct_count: usize = sorted_buckets
+        .iter()
+        .filter(|b| b.is_correct)
+        .map(|b| b.execution_indices.len())
+        .sum();
+
+    let incorrect_count: usize = sorted_buckets
+        .iter()
+        .filter(|b| !b.is_correct)
+        .map(|b| b.execution_indices.len())
+        .sum();
+
+    writeln!(
+        output,
+        "- **Correct predictions**: {} ({:.1}%)",
+        correct_count,
+        (correct_count as f64 / total_executions as f64) * 100.0
+    )?;
+
+    writeln!(
+        output,
+        "- **Incorrect predictions**: {} ({:.1}%)",
+        incorrect_count,
+        (incorrect_count as f64 / total_executions as f64) * 100.0
+    )?;
+
+    writeln!(
+        output,
+        "- **No Predictions**: {} ({:.1}%)",
+        empty_predictions.len(),
+        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
+    )?;
+
+    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
+    writeln!(
+        output,
+        "- **Unique incorrect edit patterns**: {}\n",
+        unique_incorrect
+    )?;
+
+    writeln!(output, "---\n")?;
+
+    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
+        if idx == 0 {
+            writeln!(
+                output,
+                "## Correct Predictions ({} occurrences)\n",
+                bucket.execution_indices.len()
+            )?;
+        }
+
+        writeln!(output, "**Predicted Edit:**\n")?;
+        writeln!(output, "```diff")?;
+        writeln!(output, "{}", bucket.diff)?;
+        writeln!(output, "```\n")?;
+
+        writeln!(
+            output,
+            "**Executions:** {}\n",
+            bucket.execution_indices.join(", ")
+        )?;
+        writeln!(output, "---\n")?;
+    }
+
+    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
+        writeln!(
+            output,
+            "## Incorrect Prediction #{} ({} occurrences)\n",
+            idx + 1,
+            bucket.execution_indices.len()
+        )?;
+
+        writeln!(output, "**Predicted Edit:**\n")?;
+        writeln!(output, "```diff")?;
+        writeln!(output, "{}", bucket.diff)?;
+        writeln!(output, "```\n")?;
+
+        writeln!(
+            output,
+            "**Executions:** {}\n",
+            bucket.execution_indices.join(", ")
+        )?;
+
+        for (exec_id, reasoning) in bucket
+            .execution_indices
+            .iter()
+            .zip(bucket.reasoning_samples.iter())
+        {
+            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
+        }
+
+        writeln!(output, "\n---\n")?;
+    }
+
+    if !empty_predictions.is_empty() {
+        writeln!(
+            output,
+            "## No Predictions ({} occurrences)\n",
+            empty_predictions.len()
+        )?;
+
+        for execution_data in &empty_predictions {
+            writeln!(
+                output,
+                "{}",
+                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
+            )?;
+        }
+        writeln!(output, "\n---\n")?;
+    }
+
+    if !errors.is_empty() {
+        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
+
+        for (err, name, repetition_ix) in &errors {
+            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
+        }
+        writeln!(output, "\n---\n")?;
+    }
+
+    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
+        let exec_content = format!(
+            "\n### Execution {} `{}/{}/prediction_response.md`{}",
+            exec_id,
+            crate::paths::RUN_DIR.display(),
+            exec_id,
+            indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
+        );
+        indent_text(&exec_content, 2)
+    }
+
+    fn indent_text(text: &str, spaces: usize) -> String {
+        let indent = " ".repeat(spaces);
+        text.lines()
+            .collect::<Vec<_>>()
+            .join(&format!("\n{}", indent))
+    }
+
+    Ok(())
+}
+
+fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
+    let err = format!("{err:?}")
+        .replace("<edits", "```xml\n<edits")
+        .replace("</edits>", "</edits>\n```");
+    format!(
+        "### ERROR {name}{}\n\n{err}\n",
+        repetition_ix
+            .map(|ix| format!(" [RUN {ix:03}]"))
+            .unwrap_or_default()
+    )
+}