From a1558bc5cf38c83855054928b4722a1be252f5b6 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Mon, 2 Feb 2026 10:32:25 -0600 Subject: [PATCH] ep_cli: Integrate `ep qa` and `ep repair` into core `ep` loop (#47987) Closes #ISSUE Release Notes: - N/A *or* Added/Fixed/Improved ... --- crates/edit_prediction_cli/src/main.rs | 102 ++--- crates/edit_prediction_cli/src/progress.rs | 6 + crates/edit_prediction_cli/src/qa.rs | 387 +++++++----------- crates/edit_prediction_cli/src/repair.rs | 451 ++++++++------------- 4 files changed, 376 insertions(+), 570 deletions(-) diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 3124716791b7fa0468be44be2a78f710ecddf554..18dc4c2d6300f7e1a069aa4ab1ce962d12ac70f2 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -568,8 +568,10 @@ async fn load_examples( // Skip resume logic for --in-place since input and output are the same file, // which would incorrectly treat all input examples as already processed. if !args.in_place { - if let Some(path) = output_path { - resume_from_output(path, &mut examples); + if let Some(path) = output_path + && let Some(command) = &args.command + { + resume_from_output(path, &mut examples, command); } } @@ -594,7 +596,7 @@ fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 { hasher.finish() } -fn resume_from_output(path: &PathBuf, examples: &mut Vec) { +fn resume_from_output(path: &PathBuf, examples: &mut Vec, command: &Command) { let file = match File::open(path) { Ok(f) => f, Err(_) => return, @@ -615,8 +617,22 @@ fn resume_from_output(path: &PathBuf, examples: &mut Vec) { if let Ok(output_example) = serde_json::from_str::(&line) { let hash = spec_hash(&output_example.spec); if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) { - kept_hashes.insert(hash); - kept_lines.push(line); + let is_complete = match command { + Command::Qa(_) => output_example + .qa + .first() + .and_then(|q| q.as_ref()) + .and_then(|q| q.confidence) + .is_some(), + Command::Repair(_) => output_example.predictions.iter().any(|p| { + p.provider == PredictionProvider::Repair && p.actual_patch.is_some() + }), + _ => true, + }; + if is_complete { + kept_hashes.insert(hash); + kept_lines.push(line); + } } } } @@ -745,60 +761,7 @@ fn main() { } return; } - Command::Qa(qa_args) => { - // Read examples from input files - let mut examples = example::read_example_files(&args.inputs); - - // Apply filters - if let Some(name_filter) = &args.name { - examples.retain(|e| e.spec.name.contains(name_filter)); - } - if let Some(repo_filter) = &args.repo { - examples.retain(|e| e.spec.repository_url.contains(repo_filter)); - } - if let Some(offset) = args.offset { - examples.splice(0..offset, []); - } - if let Some(limit) = args.limit { - examples.truncate(limit); - } - - smol::block_on(async { - if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await { - eprintln!("Error: {:?}", e); - std::process::exit(1); - } - }); - return; - } - Command::Repair(repair_args) => { - // Read examples from input files - let mut examples = example::read_example_files(&args.inputs); - // Apply filters - if let Some(name_filter) = &args.name { - examples.retain(|e| e.spec.name.contains(name_filter)); - } - if let Some(repo_filter) = &args.repo { - examples.retain(|e| e.spec.repository_url.contains(repo_filter)); - } - if let Some(offset) = args.offset { - examples.splice(0..offset, []); - } - if let Some(limit) = args.limit { - examples.truncate(limit); - } - - smol::block_on(async { - if let Err(e) = - repair::run_repair(&mut examples, repair_args, output.as_ref()).await - { - eprintln!("Error: {:?}", e); - std::process::exit(1); - } - }); - return; - } _ => {} } @@ -826,6 +789,12 @@ fn main() { Command::Eval(args) => { predict::sync_batches(args.predict.provider.as_ref()).await?; } + Command::Qa(args) => { + qa::sync_batches(args).await?; + } + Command::Repair(args) => { + repair::sync_batches(args).await?; + } _ => (), } @@ -957,14 +926,19 @@ fn main() { ) .await?; } + Command::Qa(args) => { + qa::run_qa(example, args, &example_progress).await?; + } + Command::Repair(args) => { + repair::run_repair(example, args, &example_progress) + .await?; + } Command::Clean | Command::Synthesize(_) | Command::SplitCommit(_) | Command::Split(_) | Command::FilterLanguages(_) - | Command::ImportBatch(_) - | Command::Qa(_) - | Command::Repair(_) => { + | Command::ImportBatch(_) => { unreachable!() } } @@ -1062,6 +1036,12 @@ fn main() { Command::Eval(args) => { predict::sync_batches(args.predict.provider.as_ref()).await?; } + Command::Qa(args) => { + qa::sync_batches(args).await?; + } + Command::Repair(args) => { + repair::sync_batches(args).await?; + } _ => (), } diff --git a/crates/edit_prediction_cli/src/progress.rs b/crates/edit_prediction_cli/src/progress.rs index 0f9a29401ca9bb902694e5a13499cc17827aa04f..ec97f056ce1d9033fa4dc5b2a70e3528ff355545 100644 --- a/crates/edit_prediction_cli/src/progress.rs +++ b/crates/edit_prediction_cli/src/progress.rs @@ -50,6 +50,8 @@ pub enum Step { FormatPrompt, Predict, Score, + Qa, + Repair, Synthesize, PullExamples, } @@ -68,6 +70,8 @@ impl Step { Step::FormatPrompt => "Format", Step::Predict => "Predict", Step::Score => "Score", + Step::Qa => "QA", + Step::Repair => "Repair", Step::Synthesize => "Synthesize", Step::PullExamples => "Pull", } @@ -80,6 +84,8 @@ impl Step { Step::FormatPrompt => "\x1b[34m", Step::Predict => "\x1b[32m", Step::Score => "\x1b[31m", + Step::Qa => "\x1b[36m", + Step::Repair => "\x1b[95m", Step::Synthesize => "\x1b[36m", Step::PullExamples => "\x1b[36m", } diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs index 9a54353040afbe57ba431fc103b4a16f7cbca232..c84d4b5cbe31ced383113c5dfb425c07e5cdc73e 100644 --- a/crates/edit_prediction_cli/src/qa.rs +++ b/crates/edit_prediction_cli/src/qa.rs @@ -3,17 +3,20 @@ //! This module uses LLM Batch APIs to evaluate prediction quality. //! Caching is handled by the underlying client. -use crate::BatchProvider; -use crate::anthropic_client::AnthropicClient; -use crate::example::Example; -use crate::format_prompt::extract_cursor_excerpt_from_example; -use crate::openai_client::OpenAiClient; -use crate::paths::LLM_CACHE_DB; -use crate::word_diff::unified_to_word_diff; -use anyhow::Result; +use crate::{ + BatchProvider, + anthropic_client::AnthropicClient, + example::Example, + format_prompt::extract_cursor_excerpt_from_example, + openai_client::OpenAiClient, + parse_output::run_parse_output, + paths::LLM_CACHE_DB, + progress::{ExampleProgress, Step}, + word_diff::unified_to_word_diff, +}; +use anyhow::{Context as _, Result}; use serde::{Deserialize, Serialize}; -use std::io::{BufWriter, Write}; -use std::path::PathBuf; +use std::sync::OnceLock; /// Arguments for the QA command. #[derive(Debug, Clone, clap::Args)] @@ -22,10 +25,6 @@ pub struct QaArgs { #[clap(long)] pub no_batch: bool, - /// Wait for batch to complete (polls every 30s) - #[clap(long)] - pub wait: bool, - /// Which LLM provider to use (anthropic or openai) #[clap(long, default_value = "openai")] pub backend: BatchProvider, @@ -63,15 +62,24 @@ pub struct QaResult { } /// Build the assessment prompt for an example. -pub fn build_prompt(example: &Example) -> Option { - let prediction = example.predictions.first()?; - let actual_patch = prediction.actual_patch.as_ref()?; - let prompt_inputs = example.prompt_inputs.as_ref()?; +pub fn build_prompt(example: &Example) -> Result { + let prediction = example + .predictions + .first() + .context("no predictions available")?; + let actual_patch = prediction + .actual_patch + .as_ref() + .context("no actual_patch available (run predict first)")?; + let prompt_inputs = example + .prompt_inputs + .as_ref() + .context("prompt_inputs missing (run context retrieval first)")?; let actual_patch_word_diff = unified_to_word_diff(actual_patch); - // Format cursor excerpt (reuse from format_prompt) - let cursor_excerpt = extract_cursor_excerpt_from_example(example)?; + let cursor_excerpt = + extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?; let mut edit_history = String::new(); for event in &prompt_inputs.edit_history { @@ -93,15 +101,12 @@ pub fn build_prompt(example: &Example) -> Option { } let prompt_template = crate::prompt_assets::get_prompt("qa.md"); - Some( - prompt_template - .replace("{edit_history}", &edit_history) - .replace("{cursor_excerpt}", &cursor_excerpt) - .replace("{actual_patch_word_diff}", &actual_patch_word_diff), - ) + Ok(prompt_template + .replace("{edit_history}", &edit_history) + .replace("{cursor_excerpt}", &cursor_excerpt) + .replace("{actual_patch_word_diff}", &actual_patch_word_diff)) } -/// Extract a code block from a response. fn extract_codeblock(response: &str) -> Option { let lines: Vec<&str> = response.lines().collect(); for (i, line) in lines.iter().enumerate() { @@ -118,11 +123,9 @@ fn extract_codeblock(response: &str) -> Option { None } -/// Parse the LLM response into a QaResult. fn parse_response(response_text: &str) -> QaResult { let codeblock = extract_codeblock(response_text); - // Try parsing codeblock first, then fall back to raw response for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] { let Some(text) = text_to_parse else { continue; @@ -145,7 +148,6 @@ fn parse_response(response_text: &str) -> QaResult { } } - // If all parsing attempts fail, return error QaResult { reasoning: Some(response_text.to_string()), reverts_edits: None, @@ -155,239 +157,148 @@ fn parse_response(response_text: &str) -> QaResult { } } -enum QaClient { - Anthropic(AnthropicClient), - OpenAi(OpenAiClient), -} +static ANTHROPIC_CLIENT_BATCH: OnceLock = OnceLock::new(); +static ANTHROPIC_CLIENT_PLAIN: OnceLock = OnceLock::new(); +static OPENAI_CLIENT_BATCH: OnceLock = OnceLock::new(); +static OPENAI_CLIENT_PLAIN: OnceLock = OnceLock::new(); -impl QaClient { - async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result> { - match self { - QaClient::Anthropic(client) => { - let messages = vec![anthropic::Message { - role: anthropic::Role::User, - content: vec![anthropic::RequestContent::Text { - text: prompt.to_string(), - cache_control: None, - }], - }]; - let response = client - .generate(model, max_tokens, messages, None, false) - .await?; - Ok(response.map(|r| { - r.content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => Some(text.as_str()), - _ => None, - }) - .collect::>() - .join("") - })) - } - QaClient::OpenAi(client) => { - let messages = vec![open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(prompt.to_string()), - }]; - let response = client - .generate(model, max_tokens, messages, None, false) - .await?; - Ok(response.map(|r| { - r.choices - .into_iter() - .filter_map(|choice| match choice.message { - open_ai::RequestMessage::Assistant { content, .. } => { - content.map(|c| match c { - open_ai::MessageContent::Plain(text) => text, - open_ai::MessageContent::Multipart(parts) => parts - .into_iter() - .filter_map(|p| match p { - open_ai::MessagePart::Text { text } => Some(text), - _ => None, - }) - .collect::>() - .join(""), - }) - } - _ => None, - }) - .collect::>() - .join("") - })) - } - } - } - - async fn sync_batches(&self) -> Result<()> { - match self { - QaClient::Anthropic(client) => client.sync_batches().await, - QaClient::OpenAi(client) => client.sync_batches().await, - } - } -} - -/// Run the QA evaluation on a set of examples. +/// Run QA evaluation for a single example. pub async fn run_qa( - examples: &mut [Example], + example: &mut Example, args: &QaArgs, - output_path: Option<&PathBuf>, + example_progress: &ExampleProgress, ) -> Result<()> { - let model = model_for_backend(args.backend); - let client = match args.backend { - BatchProvider::Anthropic => { - if args.no_batch { - QaClient::Anthropic(AnthropicClient::plain()?) - } else { - QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?) - } - } - BatchProvider::Openai => { - if args.no_batch { - QaClient::OpenAi(OpenAiClient::plain()?) - } else { - QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?) - } - } - }; - - eprintln!( - "Using model: {}, backend: {:?}, batching: {}", - model, args.backend, !args.no_batch - ); - - // First pass: send requests (client handles caching internally) - let mut prompts: Vec<(usize, String)> = Vec::new(); - let mut skipped_count = 0; - - for (idx, example) in examples.iter().enumerate() { - let Some(prompt) = build_prompt(example) else { - skipped_count += 1; - continue; - }; - prompts.push((idx, prompt)); + if example + .qa + .first() + .and_then(|q| q.as_ref()) + .and_then(|q| q.confidence) + .is_some() + { + return Ok(()); } - if skipped_count > 0 { - eprintln!("Skipping {} items with missing actual_patch", skipped_count); + run_parse_output(example).context("Failed to execute run_parse_output")?; + + if example.prompt_inputs.is_none() { + anyhow::bail!("prompt_inputs missing (run context retrieval first)"); } - eprintln!("{} items to process", prompts.len()); + let step_progress = example_progress.start(Step::Qa); - // Process all items - let mut results: Vec<(usize, Option)> = Vec::new(); + let model = model_for_backend(args.backend); + let prompt = build_prompt(example).context("Failed to build QA prompt")?; - if args.no_batch { - // Synchronous processing - for (i, (idx, prompt)) in prompts.iter().enumerate() { - eprint!("\rProcessing {}/{}", i + 1, prompts.len()); + step_progress.set_substatus("generating"); - let response = client.generate(model, 1024, prompt).await?; - let result = response.map(|text| parse_response(&text)); - results.push((*idx, result)); - } - eprintln!(); - } else { - // Queue all for batching - for (idx, prompt) in &prompts { - let response = client.generate(model, 1024, prompt).await?; - let result = response.map(|text| parse_response(&text)); - results.push((*idx, result)); - } + let response = match args.backend { + BatchProvider::Anthropic => { + let client = if args.no_batch { + ANTHROPIC_CLIENT_PLAIN.get_or_init(|| { + AnthropicClient::plain().expect("Failed to create Anthropic client") + }) + } else { + ANTHROPIC_CLIENT_BATCH.get_or_init(|| { + AnthropicClient::batch(&LLM_CACHE_DB) + .expect("Failed to create Anthropic client") + }) + }; - // Sync batches (upload pending, download finished) - client.sync_batches().await?; - - if args.wait { - eprintln!("Waiting for batch to complete..."); - loop { - std::thread::sleep(std::time::Duration::from_secs(30)); - client.sync_batches().await?; - - // Re-check all items that didn't have results - let mut all_done = true; - for (result_idx, (idx, prompt)) in prompts.iter().enumerate() { - if results[result_idx].1.is_none() { - let response = client.generate(model, 1024, prompt).await?; - if let Some(text) = response { - results[result_idx] = (*idx, Some(parse_response(&text))); - } else { - all_done = false; - } - } - } + let messages = vec![anthropic::Message { + role: anthropic::Role::User, + content: vec![anthropic::RequestContent::Text { + text: prompt, + cache_control: None, + }], + }]; - let done_count = results.iter().filter(|(_, r)| r.is_some()).count(); - if all_done { - break; - } - eprintln!("Still waiting... {}/{} results", done_count, prompts.len()); - } - } else { - let pending_count = results.iter().filter(|(_, r)| r.is_none()).count(); - if pending_count > 0 { - eprintln!( - "Batch submitted. {} pending. Run again later to retrieve results.", - pending_count - ); - } - } - } + let Some(response) = client.generate(model, 1024, messages, None, false).await? else { + return Ok(()); + }; - // Build results map by index - let mut results_by_idx: std::collections::HashMap = - std::collections::HashMap::new(); - for (idx, result) in results { - if let Some(r) = result { - results_by_idx.insert(idx, r); + response + .content + .iter() + .filter_map(|c| match c { + anthropic::ResponseContent::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("") } - } + BatchProvider::Openai => { + let client = if args.no_batch { + OPENAI_CLIENT_PLAIN + .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client")) + } else { + OPENAI_CLIENT_BATCH.get_or_init(|| { + OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client") + }) + }; - // Output results - let mut writer: Box = if let Some(path) = output_path { - Box::new(BufWriter::new(std::fs::File::create(path)?)) - } else { - Box::new(std::io::stdout()) - }; + let messages = vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }]; - let mut num_total = 0; - let mut num_reverts_edits = 0; + let Some(response) = client.generate(model, 1024, messages, None, false).await? else { + return Ok(()); + }; - for (idx, example) in examples.iter_mut().enumerate() { - // Skip examples that couldn't be processed - if build_prompt(example).is_none() { - continue; + response + .choices + .into_iter() + .filter_map(|choice| match choice.message { + open_ai::RequestMessage::Assistant { content, .. } => { + content.map(|c| match c { + open_ai::MessageContent::Plain(text) => text, + open_ai::MessageContent::Multipart(parts) => parts + .into_iter() + .filter_map(|p| match p { + open_ai::MessagePart::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join(""), + }) + } + _ => None, + }) + .collect::>() + .join("") } + }; - let result = results_by_idx.get(&idx).cloned(); - - if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) { - num_reverts_edits += 1; - } - num_total += 1; + let result = parse_response(&response); - // Populate QA results for each prediction (currently only first prediction is evaluated) - example.qa = example - .predictions - .iter() - .enumerate() - .map(|(i, _)| if i == 0 { result.clone() } else { None }) - .collect(); + example.qa = example + .predictions + .iter() + .enumerate() + .map(|(i, _)| if i == 0 { Some(result.clone()) } else { None }) + .collect(); - writeln!(writer, "{}", serde_json::to_string(&example)?)?; - } + Ok(()) +} - if let Some(path) = output_path { - eprintln!("Results written to {}", path.display()); +/// Sync batches for QA (upload pending requests, download finished results). +pub async fn sync_batches(args: &QaArgs) -> Result<()> { + if args.no_batch { + return Ok(()); } - eprintln!("Processed: {} items", num_total); - if num_total > 0 { - eprintln!( - "Reverts edits: {} ({:.2}%)", - num_reverts_edits, - num_reverts_edits as f64 / num_total as f64 * 100.0 - ); + match args.backend { + BatchProvider::Anthropic => { + let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| { + AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client") + }); + client.sync_batches().await?; + } + BatchProvider::Openai => { + let client = OPENAI_CLIENT_BATCH.get_or_init(|| { + OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client") + }); + client.sync_batches().await?; + } } - Ok(()) } diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs index 23b89133ae6183ee9444be7c1da21668351f8d2d..9f1504d530f60b90d2ad8171bd79a08c8067eae6 100644 --- a/crates/edit_prediction_cli/src/repair.rs +++ b/crates/edit_prediction_cli/src/repair.rs @@ -4,17 +4,19 @@ //! predictions that need improvement (based on reverts_edits or low confidence), //! and uses an LLM to generate improved predictions. -use crate::BatchProvider; -use crate::PredictionProvider; -use crate::anthropic_client::AnthropicClient; -use crate::example::{Example, ExamplePrediction}; -use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example}; -use crate::openai_client::OpenAiClient; -use crate::paths::LLM_CACHE_DB; -use crate::word_diff::unified_to_word_diff; -use anyhow::Result; -use std::io::{BufWriter, Write}; -use std::path::PathBuf; +use crate::{ + BatchProvider, PredictionProvider, + anthropic_client::AnthropicClient, + example::{Example, ExamplePrediction}, + format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example}, + openai_client::OpenAiClient, + parse_output::run_parse_output, + paths::LLM_CACHE_DB, + progress::{ExampleProgress, Step}, + word_diff::unified_to_word_diff, +}; +use anyhow::{Context as _, Result}; +use std::sync::OnceLock; /// Arguments for the repair command. #[derive(Debug, Clone, clap::Args)] @@ -23,10 +25,6 @@ pub struct RepairArgs { #[clap(long)] pub no_batch: bool, - /// Wait for batch to complete (polls every 30s) - #[clap(long)] - pub wait: bool, - /// Confidence threshold: repair predictions with confidence <= this value (1-5) #[clap(long, default_value = "2")] pub confidence_threshold: u8, @@ -44,17 +42,28 @@ fn model_for_backend(backend: BatchProvider) -> &'static str { } /// Build the repair prompt for an example that needs improvement. -/// -/// Returns None if the example doesn't have the required data (predictions, qa, prompt_inputs). -pub fn build_repair_prompt(example: &Example) -> Option { - let prediction = example.predictions.first()?; - let qa = example.qa.first()?.as_ref()?; - let prompt_inputs = example.prompt_inputs.as_ref()?; - let actual_patch = prediction.actual_patch.as_ref()?; +pub fn build_repair_prompt(example: &Example) -> Result { + let prediction = example + .predictions + .first() + .context("no predictions available")?; + let qa = example + .qa + .first() + .context("no QA results available")? + .as_ref() + .context("QA result is None")?; + let prompt_inputs = example + .prompt_inputs + .as_ref() + .context("prompt_inputs missing (run context retrieval first)")?; + let actual_patch = prediction + .actual_patch + .as_ref() + .context("no actual_patch available (run predict first)")?; let actual_patch_word_diff = unified_to_word_diff(actual_patch); - // Format edit history similar to qa.rs let mut edit_history = String::new(); for event in &prompt_inputs.edit_history { match event.as_ref() { @@ -74,13 +83,11 @@ pub fn build_repair_prompt(example: &Example) -> Option { } } - // Format related files context (reuse from TeacherPrompt) let context = TeacherPrompt::format_context(example); - // Format cursor excerpt with editable region markers (reuse from format_prompt) - let cursor_excerpt = extract_cursor_excerpt_from_example(example)?; + let cursor_excerpt = + extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?; - // Get QA feedback let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided"); let reverts_edits = qa .reverts_edits @@ -90,16 +97,14 @@ pub fn build_repair_prompt(example: &Example) -> Option { .map_or("unknown".to_string(), |v| v.to_string()); let prompt_template = crate::prompt_assets::get_prompt("repair.md"); - Some( - prompt_template - .replace("{edit_history}", &edit_history) - .replace("{context}", &context) - .replace("{cursor_excerpt}", &cursor_excerpt) - .replace("{actual_patch_word_diff}", &actual_patch_word_diff) - .replace("{reverts_edits}", reverts_edits) - .replace("{confidence}", &confidence) - .replace("{qa_reasoning}", qa_reasoning), - ) + Ok(prompt_template + .replace("{edit_history}", &edit_history) + .replace("{context}", &context) + .replace("{cursor_excerpt}", &cursor_excerpt) + .replace("{actual_patch_word_diff}", &actual_patch_word_diff) + .replace("{reverts_edits}", reverts_edits) + .replace("{confidence}", &confidence) + .replace("{qa_reasoning}", qa_reasoning)) } /// Check if an example needs repair based on QA feedback. @@ -108,12 +113,10 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool { return false; }; - // Repair if reverts_edits is true if qa.reverts_edits == Some(true) { return true; } - // Repair if confidence is at or below threshold if let Some(confidence) = qa.confidence { if confidence <= confidence_threshold { return true; @@ -123,264 +126,170 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool { false } -/// Parse the repair response into a prediction. -fn parse_repair_response(example: &Example, response_text: &str) -> Result { - let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, response_text)?; - - Ok(ExamplePrediction { - actual_patch: Some(actual_patch), - actual_output: response_text.to_string(), - actual_cursor_offset, - error: None, - provider: PredictionProvider::Repair, - }) +/// Check if an example already has a successful repair prediction. +fn has_successful_repair(example: &Example) -> bool { + example + .predictions + .iter() + .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some()) } -enum RepairClient { - Anthropic(AnthropicClient), - OpenAi(OpenAiClient), -} +static ANTHROPIC_CLIENT_BATCH: OnceLock = OnceLock::new(); +static ANTHROPIC_CLIENT_PLAIN: OnceLock = OnceLock::new(); +static OPENAI_CLIENT_BATCH: OnceLock = OnceLock::new(); +static OPENAI_CLIENT_PLAIN: OnceLock = OnceLock::new(); -impl RepairClient { - async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result> { - match self { - RepairClient::Anthropic(client) => { - let messages = vec![anthropic::Message { - role: anthropic::Role::User, - content: vec![anthropic::RequestContent::Text { - text: prompt.to_string(), - cache_control: None, - }], - }]; - let response = client - .generate(model, max_tokens, messages, None, false) - .await?; - Ok(response.map(|r| { - r.content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => Some(text.as_str()), - _ => None, - }) - .collect::>() - .join("") - })) - } - RepairClient::OpenAi(client) => { - let messages = vec![open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(prompt.to_string()), - }]; - let response = client - .generate(model, max_tokens, messages, None, false) - .await?; - Ok(response.map(|r| { - r.choices - .into_iter() - .filter_map(|choice| match choice.message { - open_ai::RequestMessage::Assistant { content, .. } => { - content.map(|c| match c { - open_ai::MessageContent::Plain(text) => text, - open_ai::MessageContent::Multipart(parts) => parts - .into_iter() - .filter_map(|p| match p { - open_ai::MessagePart::Text { text } => Some(text), - _ => None, - }) - .collect::>() - .join(""), - }) - } - _ => None, - }) - .collect::>() - .join("") - })) - } - } +/// Run repair for a single example. +pub async fn run_repair( + example: &mut Example, + args: &RepairArgs, + example_progress: &ExampleProgress, +) -> Result<()> { + if has_successful_repair(example) { + return Ok(()); } - async fn sync_batches(&self) -> Result<()> { - match self { - RepairClient::Anthropic(client) => client.sync_batches().await, - RepairClient::OpenAi(client) => client.sync_batches().await, - } + if !needs_repair(example, args.confidence_threshold) { + return Ok(()); } -} -/// Run the repair process on a set of examples. -pub async fn run_repair( - examples: &mut [Example], - args: &RepairArgs, - output_path: Option<&PathBuf>, -) -> Result<()> { + run_parse_output(example).context("Failed to execute run_parse_output")?; + + if example.prompt_inputs.is_none() { + anyhow::bail!("prompt_inputs missing (run context retrieval first)"); + } + + if example.predictions.is_empty() { + anyhow::bail!("no predictions available (run predict first)"); + } + + if example.qa.is_empty() { + anyhow::bail!("no QA results available (run qa first)"); + } + + let step_progress = example_progress.start(Step::Repair); + let model = model_for_backend(args.backend); - let client = match args.backend { + let prompt = build_repair_prompt(example).context("Failed to build repair prompt")?; + + step_progress.set_substatus("generating"); + + let response = match args.backend { BatchProvider::Anthropic => { - if args.no_batch { - RepairClient::Anthropic(AnthropicClient::plain()?) + let client = if args.no_batch { + ANTHROPIC_CLIENT_PLAIN.get_or_init(|| { + AnthropicClient::plain().expect("Failed to create Anthropic client") + }) } else { - RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?) - } + ANTHROPIC_CLIENT_BATCH.get_or_init(|| { + AnthropicClient::batch(&LLM_CACHE_DB) + .expect("Failed to create Anthropic client") + }) + }; + + let messages = vec![anthropic::Message { + role: anthropic::Role::User, + content: vec![anthropic::RequestContent::Text { + text: prompt, + cache_control: None, + }], + }]; + + let Some(response) = client.generate(model, 16384, messages, None, false).await? else { + return Ok(()); + }; + + response + .content + .iter() + .filter_map(|c| match c { + anthropic::ResponseContent::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("") } BatchProvider::Openai => { - if args.no_batch { - RepairClient::OpenAi(OpenAiClient::plain()?) + let client = if args.no_batch { + OPENAI_CLIENT_PLAIN + .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client")) } else { - RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?) - } + OPENAI_CLIENT_BATCH.get_or_init(|| { + OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client") + }) + }; + + let messages = vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }]; + + let Some(response) = client.generate(model, 16384, messages, None, false).await? else { + return Ok(()); + }; + + response + .choices + .into_iter() + .filter_map(|choice| match choice.message { + open_ai::RequestMessage::Assistant { content, .. } => { + content.map(|c| match c { + open_ai::MessageContent::Plain(text) => text, + open_ai::MessageContent::Multipart(parts) => parts + .into_iter() + .filter_map(|p| match p { + open_ai::MessagePart::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join(""), + }) + } + _ => None, + }) + .collect::>() + .join("") } }; - eprintln!( - "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}", - model, args.backend, !args.no_batch, args.confidence_threshold - ); - - // First pass: identify examples that need repair and build prompts - let mut repair_items: Vec<(usize, String)> = Vec::new(); - let mut skipped_missing_data = 0; - let mut skipped_no_repair_needed = 0; - - for (idx, example) in examples.iter().enumerate() { - // Skip if missing predictions or qa - if example.predictions.is_empty() || example.qa.is_empty() { - skipped_missing_data += 1; - continue; - } - - // Skip if doesn't need repair - if !needs_repair(example, args.confidence_threshold) { - skipped_no_repair_needed += 1; - continue; - } - - // Build repair prompt - let Some(prompt) = build_repair_prompt(example) else { - skipped_missing_data += 1; - continue; - }; + let parse_result = TeacherPrompt::parse(example, &response); + let err = parse_result + .as_ref() + .err() + .map(|e| format!("Failed to parse repair response: {}", e)); - repair_items.push((idx, prompt)); - } + let (actual_patch, actual_cursor_offset) = parse_result.ok().unzip(); - eprintln!( - "Skipping {} items with missing data, {} items that don't need repair", - skipped_missing_data, skipped_no_repair_needed - ); - eprintln!("{} items to repair", repair_items.len()); + example.predictions.push(ExamplePrediction { + actual_patch, + actual_output: response, + actual_cursor_offset: actual_cursor_offset.flatten(), + error: err, + provider: PredictionProvider::Repair, + }); - // Process all items - let mut results: Vec<(usize, Option)> = Vec::new(); + Ok(()) +} +/// Sync batches for repair (upload pending requests, download finished results). +pub async fn sync_batches(args: &RepairArgs) -> Result<()> { if args.no_batch { - // Synchronous processing - for (i, (idx, prompt)) in repair_items.iter().enumerate() { - eprint!("\rProcessing {}/{}", i + 1, repair_items.len()); - - let response = client.generate(model, 16384, prompt).await?; - results.push((*idx, response)); - } - eprintln!(); - } else { - // Queue all for batching - for (idx, prompt) in &repair_items { - let response = client.generate(model, 16384, prompt).await?; - results.push((*idx, response)); - } - - // Sync batches (upload pending, download finished) - client.sync_batches().await?; - - if args.wait { - eprintln!("Waiting for batch to complete..."); - loop { - std::thread::sleep(std::time::Duration::from_secs(30)); - client.sync_batches().await?; - - // Re-check all items that didn't have results - let mut all_done = true; - for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() { - if results[result_idx].1.is_none() { - let response = client.generate(model, 16384, prompt).await?; - if let Some(text) = response { - results[result_idx] = (*idx, Some(text)); - } else { - all_done = false; - } - } - } - - let done_count = results.iter().filter(|(_, r)| r.is_some()).count(); - if all_done { - break; - } - eprintln!( - "Still waiting... {}/{} results", - done_count, - repair_items.len() - ); - } - } else { - let pending_count = results.iter().filter(|(_, r)| r.is_none()).count(); - if pending_count > 0 { - eprintln!( - "Batch submitted. {} pending. Run again later to retrieve results.", - pending_count - ); - } - } + return Ok(()); } - // Build results map by index - let mut results_by_idx: std::collections::HashMap = - std::collections::HashMap::new(); - for (idx, result) in results { - if let Some(r) = result { - results_by_idx.insert(idx, r); + match args.backend { + BatchProvider::Anthropic => { + let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| { + AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client") + }); + client.sync_batches().await?; } - } - - // Output results - let mut writer: Box = if let Some(path) = output_path { - Box::new(BufWriter::new(std::fs::File::create(path)?)) - } else { - Box::new(std::io::stdout()) - }; - - let mut num_repaired = 0; - let mut num_repair_errors = 0; - - for (idx, example) in examples.iter_mut().enumerate() { - // Add repair prediction if we have a result - if let Some(response_text) = results_by_idx.get(&idx) { - match parse_repair_response(example, response_text) { - Ok(prediction) => { - example.predictions.push(prediction); - num_repaired += 1; - } - Err(e) => { - // Add error prediction - example.predictions.push(ExamplePrediction { - actual_patch: None, - actual_output: response_text.clone(), - actual_cursor_offset: None, - error: Some(format!("Failed to parse repair response: {}", e)), - provider: PredictionProvider::Repair, - }); - num_repair_errors += 1; - } - } + BatchProvider::Openai => { + let client = OPENAI_CLIENT_BATCH.get_or_init(|| { + OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client") + }); + client.sync_batches().await?; } - - writeln!(writer, "{}", serde_json::to_string(&example)?)?; } - if let Some(path) = output_path { - eprintln!("Results written to {}", path.display()); - } - - eprintln!("Repaired: {} items", num_repaired); - eprintln!("Repair errors: {} items", num_repair_errors); - Ok(()) }