Run QA for all predictions

Created by Oleksiy Syvokon

Change summary

crates/edit_prediction_cli/src/predict.rs | 121 +++++++++++++++---------
crates/edit_prediction_cli/src/qa.rs      |  12 ++
crates/edit_prediction_cli/src/repair.rs  |  17 +++
3 files changed, 103 insertions(+), 47 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/predict.rs

@@ -10,7 +10,7 @@ use crate::{
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
     progress::{ExampleProgress, InfoStyle, Step},
     qa,
-    repair::{build_repair_prompt, needs_repair, parse_repair_response},
+    repair::{build_repair_prompt_for_prediction, needs_repair_qa, parse_repair_response},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
@@ -429,10 +429,13 @@ async fn predict_openai(
 /// Default confidence threshold for repair
 const DEFAULT_REPAIR_CONFIDENCE_THRESHOLD: u8 = 3;
 
-/// Predict using teacher model, then run QA evaluation, and optionally repair
-/// if QA indicates issues (reverts_edits=true or low confidence).
+/// Predict using teacher model, then run QA evaluation on all predictions,
+/// and replace predictions that need repair.
 ///
 /// This is a non-batched flow that processes each step synchronously.
+/// - Predictions that pass QA keep their original Teacher provider
+/// - Predictions that fail QA are replaced with repaired versions (RepairedTeacher provider)
+/// - QA results are not stored because they would be outdated after replacement
 async fn predict_repaired_teacher(
     example: &mut Example,
     backend: TeacherBackend,
@@ -441,65 +444,93 @@ async fn predict_repaired_teacher(
     // Step 1: Run teacher prediction (non-batched for immediate results)
     predict_teacher(example, backend, false, repetition_count).await?;
 
-    // Only proceed with QA/repair for the first prediction
-    let Some(prediction) = example.predictions.first() else {
-        return Ok(());
-    };
-
-    // Skip QA if no actual patch was generated
-    if prediction.actual_patch.is_none() {
+    if example.predictions.is_empty() {
         return Ok(());
     }
 
-    // Step 2: Run QA evaluation
     let batch_provider = match backend {
         TeacherBackend::Sonnet45 => BatchProvider::Anthropic,
         TeacherBackend::Gpt52 => BatchProvider::Openai,
     };
-    let qa_client = LlmClient::new(batch_provider, false)?;
-    let qa_model = model_for_backend(batch_provider);
-
-    let qa_result = if let Some(qa_prompt) = qa::build_prompt(example) {
-        match qa_client.generate(qa_model, 1024, &qa_prompt).await? {
-            Some(response_text) => Some(qa::parse_response(&response_text)),
-            None => None,
+    let llm_client = LlmClient::new(batch_provider, false)?;
+    let model = model_for_backend(batch_provider);
+
+    // Step 2: Run QA for all predictions and repair those that need it
+    let mut final_predictions = Vec::with_capacity(example.predictions.len());
+    let mut final_qa = Vec::with_capacity(example.predictions.len());
+
+    for prediction in &example.predictions {
+        // Skip QA if no actual patch was generated
+        if prediction.actual_patch.is_none() {
+            final_predictions.push(prediction.clone());
+            final_qa.push(None);
+            continue;
         }
-    } else {
-        None
-    };
-
-    // Store QA result
-    example.qa = vec![qa_result.clone()];
 
-    // Step 3: Check if repair is needed and run repair if so
-    if needs_repair(example, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD) {
-        let repair_client = LlmClient::new(batch_provider, false)?;
+        // Run QA evaluation for this prediction
+        let qa_result =
+            if let Some(qa_prompt) = qa::build_prompt_for_prediction(example, prediction) {
+                match llm_client.generate(model, 1024, &qa_prompt).await? {
+                    Some(response_text) => Some(qa::parse_response(&response_text)),
+                    None => None,
+                }
+            } else {
+                None
+            };
 
-        if let Some(repair_prompt) = build_repair_prompt(example) {
-            if let Some(response_text) = repair_client
-                .generate(qa_model, 16384, &repair_prompt)
-                .await?
+        // Check if repair is needed
+        let needs_repair = qa_result
+            .as_ref()
+            .map(|qa| needs_repair_qa(qa, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD))
+            .unwrap_or(false);
+
+        if needs_repair {
+            let qa = qa_result
+                .as_ref()
+                .expect("qa_result must be Some if needs_repair is true");
+            // Step 3: Run repair for this prediction
+            if let Some(repair_prompt) = build_repair_prompt_for_prediction(example, prediction, qa)
             {
-                match parse_repair_response(example, &response_text) {
-                    Ok(mut repaired_prediction) => {
-                        // Mark the prediction as coming from repaired-teacher
-                        repaired_prediction.provider = PredictionProvider::RepairedTeacher(backend);
-                        example.predictions.push(repaired_prediction);
-                    }
-                    Err(e) => {
-                        // Add error prediction if parsing failed
-                        example.predictions.push(ExamplePrediction {
-                            actual_patch: None,
-                            actual_output: response_text,
-                            error: Some(format!("Failed to parse repair response: {}", e)),
-                            provider: PredictionProvider::RepairedTeacher(backend),
-                        });
+                if let Some(response_text) =
+                    llm_client.generate(model, 16384, &repair_prompt).await?
+                {
+                    match parse_repair_response(example, &response_text) {
+                        Ok(mut repaired_prediction) => {
+                            repaired_prediction.provider =
+                                PredictionProvider::RepairedTeacher(backend);
+                            final_predictions.push(repaired_prediction);
+                            final_qa.push(qa_result);
+                        }
+                        Err(e) => {
+                            final_predictions.push(ExamplePrediction {
+                                actual_patch: None,
+                                actual_output: response_text,
+                                error: Some(format!("Failed to parse repair response: {}", e)),
+                                provider: PredictionProvider::RepairedTeacher(backend),
+                            });
+                            final_qa.push(qa_result);
+                        }
                     }
+                } else {
+                    // Repair generation returned None, keep original
+                    final_predictions.push(prediction.clone());
+                    final_qa.push(qa_result);
                 }
+            } else {
+                // Couldn't build repair prompt, keep original
+                final_predictions.push(prediction.clone());
+                final_qa.push(qa_result);
             }
+        } else {
+            // No repair needed, keep original (with Teacher provider)
+            final_predictions.push(prediction.clone());
+            final_qa.push(qa_result);
         }
     }
 
+    example.predictions = final_predictions;
+    example.qa = final_qa;
+
     Ok(())
 }
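
As a reading aid, not part of the diff: once predict_repaired_teacher returns, example.predictions and example.qa have the same length and are index-aligned, so each final prediction sits next to the QA result recorded for its slot. A minimal sketch of a hypothetical downstream consumer, using only identifiers that appear above:

    // Hypothetical consumer: pair each prediction with its recorded QA result.
    for (prediction, qa) in example.predictions.iter().zip(example.qa.iter()) {
        match prediction.provider {
            // Slot was replaced because QA flagged the original teacher output.
            PredictionProvider::RepairedTeacher(_) => { /* use the repaired patch */ }
            // Original teacher prediction was kept; qa is None when no patch was produced.
            _ => { /* use the original patch */ }
        }
        let _ = qa;
    }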
 

crates/edit_prediction_cli/src/qa.rs

@@ -4,7 +4,7 @@
 //! Caching is handled by the underlying client.
 
 use crate::BatchProvider;
-use crate::example::Example;
+use crate::example::{Example, ExamplePrediction};
 use crate::format_prompt::extract_cursor_excerpt_from_example;
 use crate::llm_client::{LlmClient, model_for_backend};
 use crate::word_diff::unified_to_word_diff;
@@ -55,9 +55,17 @@ pub struct QaResult {
     pub error: Option<String>,
 }
 
-/// Build the assessment prompt for an example.
+/// Build the assessment prompt for an example (uses first prediction).
 pub fn build_prompt(example: &Example) -> Option<String> {
     let prediction = example.predictions.first()?;
+    build_prompt_for_prediction(example, prediction)
+}
+
+/// Build the assessment prompt for a specific prediction.
+pub fn build_prompt_for_prediction(
+    example: &Example,
+    prediction: &ExamplePrediction,
+) -> Option<String> {
     let actual_patch = prediction.actual_patch.as_ref()?;
     let prompt_inputs = example.prompt_inputs.as_ref()?;
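
The pre-existing build_prompt keeps its behavior by delegating to the new helper, which is what lets predict.rs run QA per prediction. A hedged usage sketch, assuming example: &Example with prompt_inputs populated (the variable names are illustrative, not from the diff):

    // Build a QA prompt for every prediction, not just the first one.
    // Entries are None where a prompt cannot be built (e.g. no actual_patch).
    let qa_prompts: Vec<Option<String>> = example
        .predictions
        .iter()
        .map(|prediction| qa::build_prompt_for_prediction(example, prediction))
        .collect();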
 

crates/edit_prediction_cli/src/repair.rs

@@ -9,6 +9,7 @@ use crate::PredictionProvider;
 use crate::example::{Example, ExamplePrediction};
 use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
 use crate::llm_client::{LlmClient, model_for_backend};
+use crate::qa::QaResult;
 use crate::word_diff::unified_to_word_diff;
 use anyhow::Result;
 use std::io::{BufWriter, Write};
@@ -42,6 +43,17 @@ pub struct RepairArgs {
 pub fn build_repair_prompt(example: &Example) -> Option<String> {
     let prediction = example.predictions.first()?;
     let qa = example.qa.first()?.as_ref()?;
+    build_repair_prompt_for_prediction(example, prediction, qa)
+}
+
+/// Build the repair prompt for a specific prediction and QA result.
+///
+/// Returns None if the example doesn't have the required data.
+pub fn build_repair_prompt_for_prediction(
+    example: &Example,
+    prediction: &ExamplePrediction,
+    qa: &QaResult,
+) -> Option<String> {
     let prompt_inputs = example.prompt_inputs.as_ref()?;
     let actual_patch = prediction.actual_patch.as_ref()?;
 
@@ -100,6 +112,11 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
         return false;
     };
 
+    needs_repair_qa(qa, confidence_threshold)
+}
+
+/// Check if a QA result indicates repair is needed.
+pub fn needs_repair_qa(qa: &QaResult, confidence_threshold: u8) -> bool {
     // Repair if reverts_edits is true
     if qa.reverts_edits == Some(true) {
         return true;
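
Taken together, needs_repair keeps its old example-level behavior while needs_repair_qa and build_repair_prompt_for_prediction expose the same logic per prediction: repair is requested when QA reports reverts_edits = Some(true) or low confidence relative to the threshold. A condensed sketch of the decision path, assuming example: &Example, prediction: &ExamplePrediction, and a stand-in generate() for the LlmClient call made in predict.rs:

    // Condensed sketch of the per-prediction QA -> repair decision.
    // generate() is a placeholder for llm_client.generate(...) in predict.rs.
    if let Some(qa_prompt) = qa::build_prompt_for_prediction(example, prediction) {
        let qa = qa::parse_response(&generate(&qa_prompt));
        if needs_repair_qa(&qa, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD) {
            if let Some(repair_prompt) =
                build_repair_prompt_for_prediction(example, prediction, &qa)
            {
                // Send repair_prompt to the teacher model and parse the result
                // with parse_repair_response, as predict.rs does above.
            }
        }
    }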