From 71c5e14665ba8cc81c980dc2712dd63bbb5ad7da Mon Sep 17 00:00:00 2001
From: Oleksiy Syvokon
Date: Mon, 26 Jan 2026 14:49:54 +0200
Subject: [PATCH] Add `ep repair` command to address judge feedback (#47646)

Release Notes:

- N/A
---
 crates/edit_prediction_cli/src/example.rs     |   6 +
 .../edit_prediction_cli/src/format_prompt.rs  |  50 ++-
 crates/edit_prediction_cli/src/main.rs        |  42 ++-
 crates/edit_prediction_cli/src/predict.rs     |   4 +-
 crates/edit_prediction_cli/src/prompts/qa.md  |  43 +++
 .../edit_prediction_cli/src/prompts/repair.md |  71 ++++
 .../{teacher.prompt.md => prompts/teacher.md} |   0
 .../edit_prediction_cli/src/pull_examples.rs  |   2 +
 crates/edit_prediction_cli/src/qa.rs          |  87 ++---
 crates/edit_prediction_cli/src/repair.rs      | 348 ++++++++++++++++++
 10 files changed, 585 insertions(+), 68 deletions(-)
 create mode 100644 crates/edit_prediction_cli/src/prompts/qa.md
 create mode 100644 crates/edit_prediction_cli/src/prompts/repair.md
 rename crates/edit_prediction_cli/src/{teacher.prompt.md => prompts/teacher.md} (100%)
 create mode 100644 crates/edit_prediction_cli/src/repair.rs

diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs
index 0b97c71902eef53a1c893fc3c7d7637344b220b2..fd165a75d8233a1425afd54d2d16b814db9b5e15 100644
--- a/crates/edit_prediction_cli/src/example.rs
+++ b/crates/edit_prediction_cli/src/example.rs
@@ -1,5 +1,6 @@
 use crate::PredictionProvider;
 use crate::paths::WORKTREES_DIR;
+use crate::qa::QaResult;
 use anyhow::{Context as _, Result};
 use collections::HashMap;
 use edit_prediction::example_spec::ExampleSpec;
@@ -41,6 +42,10 @@ pub struct Example {
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub score: Vec,
 
+    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub qa: Vec<Option<QaResult>>,
+
     /// The application state used to process this example.
     #[serde(skip)]
     pub state: Option,
@@ -253,6 +258,7 @@ fn parse_markdown_example(input: &str) -> Result<Example> {
         prompt: None,
         predictions: Vec::new(),
         score: Vec::new(),
+        qa: Vec::new(),
         state: None,
     })
 }
diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs
index daea43a7254357e9ccc8411218070a2eb2db9568..75559dd2e7689b8937a5ab2e4d71386167e94c7d 100644
--- a/crates/edit_prediction_cli/src/format_prompt.rs
+++ b/crates/edit_prediction_cli/src/format_prompt.rs
@@ -153,7 +153,7 @@ pub fn zeta2_output_for_patch(
 pub struct TeacherPrompt;
 
 impl TeacherPrompt {
-    const PROMPT: &str = include_str!("teacher.prompt.md");
+    const PROMPT: &str = include_str!("prompts/teacher.md");
     pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
     pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
     pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
@@ -249,7 +249,7 @@ impl TeacherPrompt {
         history_lines.join("\n")
     }
 
-    fn format_context(example: &Example) -> String {
+    pub fn format_context(example: &Example) -> String {
         let related_files = example
             .prompt_inputs
             .as_ref()
@@ -331,6 +331,52 @@ impl TeacherPrompt {
     }
 }
 
+/// Extract the cursor excerpt from an example.
+/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
+pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
+    // If we have the original prompt, extract the cursor excerpt from it
+    if let Some(prompt) = &example.prompt {
+        // Find the "# 3. Current File" section and extract the content
+        if let Some(start) = prompt.input.find("# 3. Current File") {
+            let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
+            let backtick_count = prompt.input[content_start..]
+                .chars()
+                .take_while(|&c| c == '`')
+                .count();
+            let content_start = content_start + backtick_count;
+
+            // Find the path line and skip it
+            let newline_pos = prompt.input[content_start..].find('\n')?;
+            let text_start = content_start + newline_pos + 1;
+
+            // Find the closing backticks
+            let closing_pattern = "`".repeat(backtick_count);
+            let text_end = prompt.input[text_start..].find(&closing_pattern)?;
+            let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
+
+            let path_str = example.spec.cursor_path.to_string_lossy();
+            return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
+        }
+    }
+
+    // Fallback: construct from prompt_inputs if available
+    let prompt_inputs = example.prompt_inputs.as_ref()?;
+    let content = &prompt_inputs.content;
+    let cursor_offset = prompt_inputs.cursor_offset;
+
+    // Simple fallback: just show content around cursor with markers
+    let path_str = example.spec.cursor_path.to_string_lossy();
+    let mut result = format!("`````{path_str}\n");
+    result.push_str(TeacherPrompt::EDITABLE_REGION_START);
+    result.push_str(&content[..cursor_offset]);
+    result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
+    result.push_str(&content[cursor_offset..]);
+    result.push_str(TeacherPrompt::EDITABLE_REGION_END);
+    result.push_str("\n`````");
+
+    Some(result)
+}
+
 fn extract_last_codeblock(text: &str) -> String {
     let mut last_block = None;
     let mut search_start = 0;
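
Aside (illustration, not part of the patch): when no stored prompt is available, `extract_cursor_excerpt_from_example` falls back to rebuilding the excerpt from `prompt_inputs`, splitting the file content at the cursor offset and wrapping it in the same markers the teacher prompt uses. A minimal standalone sketch of that shape, with the marker constants copied from `TeacherPrompt` and a made-up path and content:

```
// Standalone sketch of the fallback excerpt format (illustrative only).
const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
const USER_CURSOR_MARKER: &str = "<|user_cursor|>";

fn fallback_excerpt(path: &str, content: &str, cursor_offset: usize) -> String {
    // Path line, then the editable region with the cursor marker spliced in.
    let mut result = format!("`````{path}\n");
    result.push_str(EDITABLE_REGION_START);
    result.push_str(&content[..cursor_offset]);
    result.push_str(USER_CURSOR_MARKER);
    result.push_str(&content[cursor_offset..]);
    result.push_str(EDITABLE_REGION_END);
    result.push_str("\n`````");
    result
}

fn main() {
    let content = "fn add(a: i32, b: i32) -> i32 {\n    a + b\n}\n";
    // Offset 36 sits inside the function body (must be a char boundary).
    println!("{}", fallback_excerpt("src/math.rs", content, 36));
}
```

Note that `cursor_offset` is a byte offset, so the slicing panics if it lands mid-UTF-8 character; the fallback in the patch has the same constraint.
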
diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index 036743b2373afa173981f2903e899fae92841d6d..79a334a5e874eaa783581726afdc699768d360a7 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -14,6 +14,7 @@ mod progress;
 mod pull_examples;
 mod qa;
 mod reorder_patch;
+mod repair;
 mod retrieve_context;
 mod score;
 mod split_commit;
@@ -169,6 +170,8 @@ enum Command {
     ImportBatch(ImportBatchArgs),
     /// Assess the quality of predictions using LLM-as-a-judge
     Qa(qa::QaArgs),
+    /// Repair predictions that received poor QA scores by generating improved predictions
+    Repair(repair::RepairArgs),
 }
 
 impl Display for Command {
@@ -207,6 +210,9 @@ impl Display for Command {
             Command::Qa(_) => {
                 write!(f, "qa")
             }
+            Command::Repair(_) => {
+                write!(f, "repair")
+            }
         }
     }
 }
@@ -242,6 +248,7 @@ enum PredictionProvider {
     Zeta2(ZetaVersion),
     Teacher(ZetaVersion),
     TeacherNonBatching(ZetaVersion),
+    Repair,
 }
 
 impl Default for PredictionProvider {
@@ -261,6 +268,7 @@ impl std::fmt::Display for PredictionProvider {
             PredictionProvider::TeacherNonBatching(version) => {
                 write!(f, "teacher-non-batching:{version}")
             }
+            PredictionProvider::Repair => write!(f, "repair"),
         }
     }
 }
@@ -285,9 +293,10 @@ impl std::str::FromStr for PredictionProvider {
             "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                 Ok(PredictionProvider::TeacherNonBatching(version))
             }
+            "repair" => Ok(PredictionProvider::Repair),
             _ => {
                 anyhow::bail!(
-                    "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
+                    "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching, repair\n\
                     For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                     Available zeta versions:\n{}",
                     ZetaVersion::options_as_string()
@@ -617,6 +626,34 @@ fn main() {
             });
             return;
         }
+        Command::Repair(repair_args) => {
+            // Read examples from input files
+            let mut examples = example::read_example_files(&args.inputs);
+
+            // Apply filters
+            if let Some(name_filter) = &args.name {
+                examples.retain(|e| e.spec.name.contains(name_filter));
+            }
+            if let Some(repo_filter) = &args.repo {
+                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
+            }
+            if let Some(offset) = args.offset {
+                examples.splice(0..offset, []);
+            }
+            if let Some(limit) = args.limit {
+                examples.truncate(limit);
+            }
+
+            smol::block_on(async {
+                if let Err(e) =
+                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
+                {
+                    eprintln!("Error: {:?}", e);
+                    std::process::exit(1);
+                }
+            });
+            return;
+        }
         _ => {}
     }
@@ -784,7 +821,8 @@ fn main() {
         | Command::Split(_)
         | Command::FilterLanguages(_)
         | Command::ImportBatch(_)
-        | Command::Qa(_) => {
+        | Command::Qa(_)
+        | Command::Repair(_) => {
             unreachable!()
         }
     }
diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs
index 1c14327bb24b54ab2b7e4d980a3f2f4c88488150..79669df01078269ca28d4ed9a2a17cfc2f0edfb1 100644
--- a/crates/edit_prediction_cli/src/predict.rs
+++ b/crates/edit_prediction_cli/src/predict.rs
@@ -107,7 +107,9 @@ pub async fn run_prediction(
         }
         PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
         PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
-        PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => {
+        PredictionProvider::Teacher(..)
+        | PredictionProvider::TeacherNonBatching(..)
+        | PredictionProvider::Repair => {
             unreachable!()
         }
     };
diff --git a/crates/edit_prediction_cli/src/prompts/qa.md b/crates/edit_prediction_cli/src/prompts/qa.md
new file mode 100644
index 0000000000000000000000000000000000000000..3964b8544b749445ed954f21b8e5060e44587f54
--- /dev/null
+++ b/crates/edit_prediction_cli/src/prompts/qa.md
@@ -0,0 +1,43 @@
+You are evaluating an edit prediction model for a code editor. The model observes a programmer's recent edit history and predicts what edit they will make next.
+
+All diffs are in the word-diff format.
+
+The model is instructed to:
+- Complete partially-applied refactoring or changes
+- Maintain consistency with established patterns and style
+- NOT delete or revert text that was just added (unless the user explicitly undid it themselves)
+
+## Edit History (chronological)
+```````
+{edit_history}
+```````
+
+## Current File
+The file where the prediction will be applied, with editable region markers showing where edits can occur:
+{cursor_excerpt}
+
+## Predicted Next Edit
+```````
+{actual_patch_word_diff}
+```````
+
+## Evaluate
+
+1. **reverts_edits**: Does the prediction undo or revert changes the user intentionally made in the **edit history**?
+
+2. **confidence**: How likely is the user to accept this suggestion?
+   - 1 = Definitely reject (wrong, nonsensical, or harmful)
+   - 2 = Probably reject (doesn't fit intent or pattern)
+   - 3 = Uncertain (plausible but not clearly correct)
+   - 4 = Probably accept (reasonable next step)
+   - 5 = Definitely accept (obvious continuation)
+
+Output JSON in this format:
+
+```
+{
+  "reasoning": "your reasoning here",
+  "reverts_edits": true/false,
+  "confidence": 1-5
+}
+```
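
Aside (illustration, not part of the patch): the judge is asked to answer with a single JSON object. As a sketch of what consuming that verdict looks like, assuming serde and serde_json, which the crate already uses; the real `QaResult` in `qa.rs` additionally keeps the raw response and an error slot, and makes every field optional so a malformed reply doesn't abort the run:

```
// Deserializing the judge's verdict (sketch; field names follow the prompt).
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Verdict {
    reasoning: String,
    reverts_edits: bool,
    confidence: u8, // 1-5 scale from the prompt
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{
        "reasoning": "Continues the rename the user started.",
        "reverts_edits": false,
        "confidence": 4
    }"#;
    let verdict: Verdict = serde_json::from_str(raw)?;
    assert!(!verdict.reverts_edits);
    assert_eq!(verdict.confidence, 4);
    println!("{verdict:?}");
    Ok(())
}
```
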
diff --git a/crates/edit_prediction_cli/src/prompts/repair.md b/crates/edit_prediction_cli/src/prompts/repair.md
new file mode 100644
index 0000000000000000000000000000000000000000..3fb32cc5f5abdbf3e03a8c08f125096da498d7da
--- /dev/null
+++ b/crates/edit_prediction_cli/src/prompts/repair.md
@@ -0,0 +1,71 @@
+# Instructions
+
+You are an edit prediction assistant in a code editor. Your task is to generate an improved prediction based on feedback from a quality assessment.
+
+A previous model generated a prediction that was judged to have issues. Your job is to generate a better prediction that addresses the feedback.
+
+## Focus on
+
+- Completing any partially-applied changes
+- Ensuring consistency with the programming style and patterns already established
+- Making edits that maintain or improve code quality
+- NOT reverting or undoing changes the user intentionally made
+
+## Rules
+
+- Do not just mechanically apply patterns: reason about what changes make sense given the context and the programmer's apparent goals.
+- Do not just fix syntax errors: look for the broader refactoring pattern and apply it systematically throughout the code.
+- Keep existing formatting unless changing it is absolutely necessary.
+- Don't write a lot of code if you're not sure what to do.
+- Do not delete or remove text that was just added in the edit history. If a recent edit introduces incomplete or incorrect code, finish or fix it in place, or simply do nothing rather than removing it. Only remove a recent edit if the history explicitly shows the user undoing it themselves.
+
+# Input Format
+
+You will be provided with:
+1. The user's *edit history*, in chronological order. Use this to infer the user's trajectory and predict the next most logical edit.
+2. A set of *related excerpts* from the user's codebase. Some of these may be needed for correctly predicting the next edit.
+   - `…` may appear within a related file to indicate that some code has been skipped.
+3. An excerpt from the user's *current file*.
+   - Within the user's current file, there is an *editable region* delimited by the `<|editable_region_start|>` and `<|editable_region_end|>` tags. You can only predict edits in this region.
+   - The `<|user_cursor|>` tag marks the user's current cursor position, as it stands after the last edit in the history.
+4. The *previous prediction* that was generated and needs improvement.
+5. *Quality assessment feedback* explaining why the previous prediction was problematic.
+
+# Output Format
+
+- Briefly explain what was wrong with the previous prediction and how you'll improve it.
+- Output the entire editable region, applying the edits that you predict the user will make next.
+- If you're unsure about some portion of the next edit, you may still predict the surrounding code (such as a function definition, `for` loop, etc.) and place the `<|user_cursor|>` within it for the user to fill in.
+- Wrap the edited code in a codeblock with exactly five backticks.
+
+# 1. User Edits History
+
+`````
+{edit_history}
+`````
+
+# 2. Related excerpts
+
+{context}
+
+# 3. Current File
+
+{cursor_excerpt}
+
+# 4. Previous Prediction (needs improvement)
+
+The previous model generated the following edit (in word-diff format):
+
+`````
+{actual_patch_word_diff}
+`````
+
+# 5. Quality Assessment Feedback
+
+- **Reverts user edits**: {reverts_edits}
+- **Confidence score**: {confidence}/5
+- **Reasoning**: {qa_reasoning}
+
+# Your Improved Prediction
+
+Based on the feedback above, generate an improved prediction. Address the issues identified in the quality assessment.
diff --git a/crates/edit_prediction_cli/src/teacher.prompt.md b/crates/edit_prediction_cli/src/prompts/teacher.md
similarity index 100%
rename from crates/edit_prediction_cli/src/teacher.prompt.md
rename to crates/edit_prediction_cli/src/prompts/teacher.md
diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs
index c4928f076db1f0d98031395b6a823cec4ef52062..fadc724f067d5f1cc907202894cf798f6d78bab3 100644
--- a/crates/edit_prediction_cli/src/pull_examples.rs
+++ b/crates/edit_prediction_cli/src/pull_examples.rs
@@ -231,6 +231,7 @@ fn examples_from_response(
             prompt: None,
             predictions: Vec::new(),
             score: Vec::new(),
+            qa: Vec::new(),
             state: None,
         }),
         Err(error) => {
@@ -756,6 +757,7 @@ fn build_rejected_example(
         prompt: None,
         predictions: Vec::new(),
         score: Vec::new(),
+        qa: Vec::new(),
         state: None,
     }
 }
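
Aside (illustration, not part of the patch): both prompt files are consumed the same way. `include_str!` pulls the template in at compile time, and each `{placeholder}` is filled with `str::replace`, as the qa.rs changes below show. A distilled sketch of the pattern, with a shortened inline template standing in for the real `prompts/qa.md`:

```
// Sketch: compile-time-style template plus placeholder substitution.
const TEMPLATE: &str =
    "## Edit History\n{edit_history}\n\n## Predicted Next Edit\n{actual_patch_word_diff}\n";

fn fill(edit_history: &str, patch_word_diff: &str) -> String {
    TEMPLATE
        .replace("{edit_history}", edit_history)
        .replace("{actual_patch_word_diff}", patch_word_diff)
}

fn main() {
    let prompt = fill("--- a/src/lib.rs\n+++ b/src/lib.rs\n...", "...");
    assert!(!prompt.contains("{edit_history}"));
    println!("{prompt}");
}
```

One caveat of plain `replace`: substitutions happen in sequence, so a placeholder-shaped string inside earlier-substituted content would itself be rewritten by a later call. The inputs here are diffs and code excerpts, so collisions are unlikely, but it is a trade-off of this approach rather than a guarantee.
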
diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs
index 6c30c6f49b85ae5bb7521db68bf99d932e413a3a..f5005e08ae9db7b9c9b4d650b46af79f1223073a 100644
--- a/crates/edit_prediction_cli/src/qa.rs
+++ b/crates/edit_prediction_cli/src/qa.rs
@@ -5,20 +5,19 @@
 use crate::anthropic_client::AnthropicClient;
 use crate::example::Example;
-use crate::paths::CACHE_DIR;
+use crate::format_prompt::extract_cursor_excerpt_from_example;
+use crate::paths::LLM_CACHE_DB;
 use crate::word_diff::unified_to_word_diff;
 use anthropic::{Message, RequestContent, Role};
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
 use std::io::{BufWriter, Write};
 use std::path::PathBuf;
-use std::sync::LazyLock;
 
 /// Model to use for QA evaluation.
 const MODEL: &str = "claude-sonnet-4-5";
 
-/// Path to the QA cache database.
-pub static QA_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("qa_cache.sqlite"));
+const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
 
 /// Arguments for the QA command.
 #[derive(Debug, Clone, clap::Args)]
@@ -64,6 +63,9 @@ pub fn build_prompt(example: &Example) -> Option<String> {
 
     let actual_patch_word_diff = unified_to_word_diff(actual_patch);
 
+    // Format cursor excerpt (reuse from format_prompt)
+    let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
+
     let mut edit_history = String::new();
     for event in &prompt_inputs.edit_history {
         match event.as_ref() {
@@ -83,49 +85,12 @@ pub fn build_prompt(example: &Example) -> Option<String> {
         }
     }
 
-    Some(format!(
-        r#"
-You are evaluating an edit prediction model for a code editor. The model observes a programmer's recent edit history and predicts what edit they will make next.
-
-All diffs are in the word-diff format.
-
-The model is instructed to:
-- Complete partially-applied refactoring or changes
-- Maintain consistency with established patterns and style
-- NOT delete or revert text that was just added (unless the user explicitly undid it themselves)
-
-## Edit History (chronological)
-```````
-{edit_history}
-```````
-
-## Predicted Next Edit
-```````
-{actual_patch_word_diff}
-```````
-
-## Evaluate
-
-1. **reverts_edits**: Does the prediction undo, or revert changes the user intentionally made in the **edit history**?
-
-2. **confidence**: How likely is the user to accept this suggestion?
-   - 1 = Definitely reject (wrong, nonsensical, or harmful)
-   - 2 = Probably reject (doesn't fit intent or pattern)
-   - 3 = Uncertain (plausible but not clearly correct)
-   - 4 = Probably accept (reasonable next step)
-   - 5 = Definitely accept (obvious continuation)
-
-Output JSON in this format:
-
-```
-{{
-  "reasoning": "your reasoning here",
-  "reverts_edits": true/false,
-  "confidence": 1-5
-}}
-```
-"#
-    ))
+    Some(
+        PROMPT_TEMPLATE
+            .replace("{edit_history}", &edit_history)
+            .replace("{cursor_excerpt}", &cursor_excerpt)
+            .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
+    )
 }
 
 /// Extract a code block from a response.
@@ -191,7 +156,7 @@ pub async fn run_qa(
     let client = if args.no_batch {
         AnthropicClient::plain()?
     } else {
-        AnthropicClient::batch(&QA_CACHE_DB)?
+        AnthropicClient::batch(&LLM_CACHE_DB)?
     };
 
     eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);
@@ -356,26 +321,22 @@ pub async fn run_qa(
             continue;
         }
 
-        let result = results_by_idx
-            .get(&idx)
-            .cloned()
-            .unwrap_or_else(|| QaResult {
-                reasoning: None,
-                reverts_edits: None,
-                confidence: None,
-                response: None,
-                error: Some("Result not found".to_string()),
-            });
+        let result = results_by_idx.get(&idx).cloned();
 
-        if result.reverts_edits == Some(true) {
+        if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
             num_reverts_edits += 1;
         }
         num_total += 1;
 
-        // Add QA result to example and output
-        let mut example_json = serde_json::to_value(&example)?;
-        example_json["qa"] = serde_json::to_value(&result)?;
-        writeln!(writer, "{}", serde_json::to_string(&example_json)?)?;
+        // Populate QA results for each prediction (currently only the first prediction is evaluated)
+        example.qa = example
+            .predictions
+            .iter()
+            .enumerate()
+            .map(|(i, _)| if i == 0 { result.clone() } else { None })
+            .collect();
+
+        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
     }
 
     if let Some(path) = output_path {
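
Aside (illustration, not part of the patch): the repair gate in `repair.rs` below is a two-clause rule. An example qualifies when the judge flagged `reverts_edits`, or when `confidence` is at or below the threshold (default 2); missing verdict data never qualifies. Distilled to a standalone predicate (`needs_repair` itself takes the whole `Example`):

```
// Distilled repair gate (illustrative; mirrors needs_repair in repair.rs).
fn should_repair(reverts_edits: Option<bool>, confidence: Option<u8>, threshold: u8) -> bool {
    if reverts_edits == Some(true) {
        return true;
    }
    matches!(confidence, Some(c) if c <= threshold)
}

fn main() {
    assert!(should_repair(Some(true), Some(5), 2)); // reverting always repairs
    assert!(should_repair(Some(false), Some(2), 2)); // at the threshold
    assert!(!should_repair(Some(false), Some(3), 2)); // above the threshold
    assert!(!should_repair(None, None, 2)); // no verdict, no repair
}
```
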
diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs
new file mode 100644
index 0000000000000000000000000000000000000000..a27205e131b19bfca7f3cedce6a1f01028b863bb
--- /dev/null
+++ b/crates/edit_prediction_cli/src/repair.rs
@@ -0,0 +1,348 @@
+//! Repair predictions that received poor QA scores.
+//!
+//! This module takes examples with predictions and QA feedback, identifies
+//! predictions that need improvement (based on reverts_edits or low confidence),
+//! and uses an LLM to generate improved predictions.
+
+use crate::PredictionProvider;
+use crate::anthropic_client::AnthropicClient;
+use crate::example::{Example, ExamplePrediction};
+use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
+use crate::paths::LLM_CACHE_DB;
+use crate::word_diff::unified_to_word_diff;
+use anthropic::{Message, RequestContent, Role};
+use anyhow::Result;
+use std::io::{BufWriter, Write};
+use std::path::PathBuf;
+
+/// Model to use for repair.
+const MODEL: &str = "claude-sonnet-4-5";
+
+const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");
+
+/// Arguments for the repair command.
+#[derive(Debug, Clone, clap::Args)]
+pub struct RepairArgs {
+    /// Use synchronous API instead of batch
+    #[clap(long)]
+    pub no_batch: bool,
+
+    /// Wait for batch to complete (polls every 30s)
+    #[clap(long)]
+    pub wait: bool,
+
+    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
+    #[clap(long, default_value = "2")]
+    pub confidence_threshold: u8,
+}
+
+/// Build the repair prompt for an example that needs improvement.
+///
+/// Returns None if the example doesn't have the required data (predictions, qa, prompt_inputs).
+pub fn build_repair_prompt(example: &Example) -> Option<String> {
+    let prediction = example.predictions.first()?;
+    let qa = example.qa.first()?.as_ref()?;
+    let prompt_inputs = example.prompt_inputs.as_ref()?;
+    let actual_patch = prediction.actual_patch.as_ref()?;
+
+    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
+
+    // Format edit history similar to qa.rs
+    let mut edit_history = String::new();
+    for event in &prompt_inputs.edit_history {
+        match event.as_ref() {
+            zeta_prompt::Event::BufferChange {
+                path,
+                old_path,
+                diff,
+                predicted: _,
+                in_open_source_repo: _,
+            } => {
+                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
+                edit_history.push_str(&format!("+++ b{}\n", path.display()));
+                let diff_word_diff = unified_to_word_diff(diff);
+                edit_history.push_str(&diff_word_diff);
+                edit_history.push_str("\n\n");
+            }
+        }
+    }
+
+    // Format related files context (reuse from TeacherPrompt)
+    let context = TeacherPrompt::format_context(example);
+
+    // Format cursor excerpt with editable region markers (reuse from format_prompt)
+    let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
+
+    // Get QA feedback
+    let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
+    let reverts_edits = qa
+        .reverts_edits
+        .map_or("unknown", |v| if v { "yes" } else { "no" });
+    let confidence = qa
+        .confidence
+        .map_or("unknown".to_string(), |v| v.to_string());
+
+    Some(
+        PROMPT_TEMPLATE
+            .replace("{edit_history}", &edit_history)
+            .replace("{context}", &context)
+            .replace("{cursor_excerpt}", &cursor_excerpt)
+            .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
+            .replace("{reverts_edits}", reverts_edits)
+            .replace("{confidence}", &confidence)
+            .replace("{qa_reasoning}", qa_reasoning),
+    )
+}
+
+/// Check if an example needs repair based on QA feedback.
+pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
+    let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) else {
+        return false;
+    };
+
+    // Repair if reverts_edits is true
+    if qa.reverts_edits == Some(true) {
+        return true;
+    }
+
+    // Repair if confidence is at or below threshold
+    if let Some(confidence) = qa.confidence {
+        if confidence <= confidence_threshold {
+            return true;
+        }
+    }
+
+    false
+}
+
+/// Parse the repair response into a prediction.
+fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
+    let actual_patch = TeacherPrompt::parse(example, response_text)?;
+
+    Ok(ExamplePrediction {
+        actual_patch: Some(actual_patch),
+        actual_output: response_text.to_string(),
+        error: None,
+        provider: PredictionProvider::Repair,
+    })
+}
+
+/// Run the repair process on a set of examples.
+pub async fn run_repair(
+    examples: &mut [Example],
+    args: &RepairArgs,
+    output_path: Option<&PathBuf>,
+) -> Result<()> {
+    let client = if args.no_batch {
+        AnthropicClient::plain()?
+    } else {
+        AnthropicClient::batch(&LLM_CACHE_DB)?
+    };
+
+    eprintln!(
+        "Using model: {}, batching: {}, confidence_threshold: {}",
+        MODEL, !args.no_batch, args.confidence_threshold
+    );
+
+    // First pass: identify examples that need repair and build prompts
+    let mut repair_items: Vec<(usize, String)> = Vec::new();
+    let mut skipped_missing_data = 0;
+    let mut skipped_no_repair_needed = 0;
+
+    for (idx, example) in examples.iter().enumerate() {
+        // Skip if missing predictions or qa
+        if example.predictions.is_empty() || example.qa.is_empty() {
+            skipped_missing_data += 1;
+            continue;
+        }
+
+        // Skip if doesn't need repair
+        if !needs_repair(example, args.confidence_threshold) {
+            skipped_no_repair_needed += 1;
+            continue;
+        }
+
+        // Build repair prompt
+        let Some(prompt) = build_repair_prompt(example) else {
+            skipped_missing_data += 1;
+            continue;
+        };
+
+        repair_items.push((idx, prompt));
+    }
+
+    eprintln!(
+        "Skipping {} items with missing data, {} items that don't need repair",
+        skipped_missing_data, skipped_no_repair_needed
+    );
+    eprintln!("{} items to repair", repair_items.len());
+
+    // Process all items
+    let mut results: Vec<(usize, Option<String>)> = Vec::new();
+
+    if args.no_batch {
+        // Synchronous processing
+        for (i, (idx, prompt)) in repair_items.iter().enumerate() {
+            eprint!("\rProcessing {}/{}", i + 1, repair_items.len());
+
+            let messages = vec![Message {
+                role: Role::User,
+                content: vec![RequestContent::Text {
+                    text: prompt.clone(),
+                    cache_control: None,
+                }],
+            }];
+
+            let response = client.generate(MODEL, 16384, messages).await?;
+            let result = response.map(|r| {
+                r.content
+                    .iter()
+                    .filter_map(|c| match c {
+                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>()
+                    .join("")
+            });
+            results.push((*idx, result));
+        }
+        eprintln!();
+    } else {
+        // Queue all for batching
+        for (idx, prompt) in &repair_items {
+            let messages = vec![Message {
+                role: Role::User,
+                content: vec![RequestContent::Text {
+                    text: prompt.clone(),
+                    cache_control: None,
+                }],
+            }];
+
+            let response = client.generate(MODEL, 16384, messages).await?;
+            let result = response.map(|r| {
+                r.content
+                    .iter()
+                    .filter_map(|c| match c {
+                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>()
+                    .join("")
+            });
+            results.push((*idx, result));
+        }
+
+        // Sync batches (upload pending, download finished)
+        client.sync_batches().await?;
+
+        if args.wait {
+            eprintln!("Waiting for batch to complete...");
+            loop {
+                std::thread::sleep(std::time::Duration::from_secs(30));
+                client.sync_batches().await?;
+
+                // Re-check all items that didn't have results
+                let mut all_done = true;
+                for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
+                    if results[result_idx].1.is_none() {
+                        let messages = vec![Message {
+                            role: Role::User,
+                            content: vec![RequestContent::Text {
+                                text: prompt.clone(),
+                                cache_control: None,
+                            }],
+                        }];
+
+                        let response = client.generate(MODEL, 16384, messages).await?;
+                        if let Some(r) = response {
+                            let text = r
+                                .content
+                                .iter()
+                                .filter_map(|c| match c {
+                                    anthropic::ResponseContent::Text { text } => {
+                                        Some(text.as_str())
+                                    }
+                                    _ => None,
+                                })
+                                .collect::<Vec<_>>()
+                                .join("");
+                            results[result_idx] = (*idx, Some(text));
+                        } else {
+                            all_done = false;
+                        }
+                    }
+                }
+
+                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
+                if all_done {
+                    break;
+                }
+                eprintln!(
+                    "Still waiting... {}/{} results",
+                    done_count,
+                    repair_items.len()
+                );
+            }
+        } else {
+            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
+            if pending_count > 0 {
+                eprintln!(
+                    "Batch submitted. {} pending. Run again later to retrieve results.",
+                    pending_count
+                );
+            }
+        }
+    }
+
+    // Build results map by index
+    let mut results_by_idx: std::collections::HashMap<usize, String> =
+        std::collections::HashMap::new();
+    for (idx, result) in results {
+        if let Some(r) = result {
+            results_by_idx.insert(idx, r);
+        }
+    }
+
+    // Output results
+    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
+        Box::new(BufWriter::new(std::fs::File::create(path)?))
+    } else {
+        Box::new(std::io::stdout())
+    };
+
+    let mut num_repaired = 0;
+    let mut num_repair_errors = 0;
+
+    for (idx, example) in examples.iter_mut().enumerate() {
+        // Add repair prediction if we have a result
+        if let Some(response_text) = results_by_idx.get(&idx) {
+            match parse_repair_response(example, response_text) {
+                Ok(prediction) => {
+                    example.predictions.push(prediction);
+                    num_repaired += 1;
+                }
+                Err(e) => {
+                    // Add error prediction
+                    example.predictions.push(ExamplePrediction {
+                        actual_patch: None,
+                        actual_output: response_text.clone(),
+                        error: Some(format!("Failed to parse repair response: {}", e)),
+                        provider: PredictionProvider::Repair,
+                    });
+                    num_repair_errors += 1;
+                }
+            }
+        }
+
+        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
+    }
+
+    if let Some(path) = output_path {
+        eprintln!("Results written to {}", path.display());
+    }
+
+    eprintln!("Repaired: {} items", num_repaired);
+    eprintln!("Repair errors: {} items", num_repair_errors);
+
+    Ok(())
+}
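
Aside (illustration, not part of the patch): taken together, the intended workflow is a loop: predict, judge with `ep qa`, then regenerate the flagged predictions with `ep repair`. Repaired predictions are appended to `predictions` under the `repair` provider, so downstream scoring can compare them against the originals. A hypothetical invocation follows; the long flags are clap's kebab-case derivations from `RepairArgs`, while the input/output plumbing comes from the top-level CLI args, and the file names and output flag here are made up:

```
ep qa predictions.jsonl --output qa.jsonl
ep repair qa.jsonl --confidence-threshold 2 --wait --output repaired.jsonl
```
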