Detailed changes
@@ -568,8 +568,10 @@ async fn load_examples(
// Skip resume logic for --in-place since input and output are the same file,
// which would incorrectly treat all input examples as already processed.
if !args.in_place {
- if let Some(path) = output_path {
- resume_from_output(path, &mut examples);
+ if let Some(path) = output_path
+ && let Some(command) = &args.command
+ {
+ resume_from_output(path, &mut examples, command);
}
}
@@ -594,7 +596,7 @@ fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
hasher.finish()
}
-fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
+fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
let file = match File::open(path) {
Ok(f) => f,
Err(_) => return,
@@ -615,8 +617,22 @@ fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
let hash = spec_hash(&output_example.spec);
if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
- kept_hashes.insert(hash);
- kept_lines.push(line);
+ let is_complete = match command {
+ Command::Qa(_) => output_example
+ .qa
+ .first()
+ .and_then(|q| q.as_ref())
+ .and_then(|q| q.confidence)
+ .is_some(),
+ Command::Repair(_) => output_example.predictions.iter().any(|p| {
+ p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
+ }),
+ _ => true,
+ };
+ if is_complete {
+ kept_hashes.insert(hash);
+ kept_lines.push(line);
+ }
}
}
}
@@ -745,60 +761,7 @@ fn main() {
}
return;
}
- Command::Qa(qa_args) => {
- // Read examples from input files
- let mut examples = example::read_example_files(&args.inputs);
-
- // Apply filters
- if let Some(name_filter) = &args.name {
- examples.retain(|e| e.spec.name.contains(name_filter));
- }
- if let Some(repo_filter) = &args.repo {
- examples.retain(|e| e.spec.repository_url.contains(repo_filter));
- }
- if let Some(offset) = args.offset {
- examples.splice(0..offset, []);
- }
- if let Some(limit) = args.limit {
- examples.truncate(limit);
- }
-
- smol::block_on(async {
- if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
- eprintln!("Error: {:?}", e);
- std::process::exit(1);
- }
- });
- return;
- }
- Command::Repair(repair_args) => {
- // Read examples from input files
- let mut examples = example::read_example_files(&args.inputs);
- // Apply filters
- if let Some(name_filter) = &args.name {
- examples.retain(|e| e.spec.name.contains(name_filter));
- }
- if let Some(repo_filter) = &args.repo {
- examples.retain(|e| e.spec.repository_url.contains(repo_filter));
- }
- if let Some(offset) = args.offset {
- examples.splice(0..offset, []);
- }
- if let Some(limit) = args.limit {
- examples.truncate(limit);
- }
-
- smol::block_on(async {
- if let Err(e) =
- repair::run_repair(&mut examples, repair_args, output.as_ref()).await
- {
- eprintln!("Error: {:?}", e);
- std::process::exit(1);
- }
- });
- return;
- }
_ => {}
}
@@ -826,6 +789,12 @@ fn main() {
Command::Eval(args) => {
predict::sync_batches(args.predict.provider.as_ref()).await?;
}
+ Command::Qa(args) => {
+ qa::sync_batches(args).await?;
+ }
+ Command::Repair(args) => {
+ repair::sync_batches(args).await?;
+ }
_ => (),
}
@@ -957,14 +926,19 @@ fn main() {
)
.await?;
}
+ Command::Qa(args) => {
+ qa::run_qa(example, args, &example_progress).await?;
+ }
+ Command::Repair(args) => {
+ repair::run_repair(example, args, &example_progress)
+ .await?;
+ }
Command::Clean
| Command::Synthesize(_)
| Command::SplitCommit(_)
| Command::Split(_)
| Command::FilterLanguages(_)
- | Command::ImportBatch(_)
- | Command::Qa(_)
- | Command::Repair(_) => {
+ | Command::ImportBatch(_) => {
unreachable!()
}
}
@@ -1062,6 +1036,12 @@ fn main() {
Command::Eval(args) => {
predict::sync_batches(args.predict.provider.as_ref()).await?;
}
+ Command::Qa(args) => {
+ qa::sync_batches(args).await?;
+ }
+ Command::Repair(args) => {
+ repair::sync_batches(args).await?;
+ }
_ => (),
}
@@ -50,6 +50,8 @@ pub enum Step {
FormatPrompt,
Predict,
Score,
+ Qa,
+ Repair,
Synthesize,
PullExamples,
}
@@ -68,6 +70,8 @@ impl Step {
Step::FormatPrompt => "Format",
Step::Predict => "Predict",
Step::Score => "Score",
+ Step::Qa => "QA",
+ Step::Repair => "Repair",
Step::Synthesize => "Synthesize",
Step::PullExamples => "Pull",
}
@@ -80,6 +84,8 @@ impl Step {
Step::FormatPrompt => "\x1b[34m",
Step::Predict => "\x1b[32m",
Step::Score => "\x1b[31m",
+ Step::Qa => "\x1b[36m",
+ Step::Repair => "\x1b[95m",
Step::Synthesize => "\x1b[36m",
Step::PullExamples => "\x1b[36m",
}
@@ -3,17 +3,20 @@
//! This module uses LLM Batch APIs to evaluate prediction quality.
//! Caching is handled by the underlying client.
-use crate::BatchProvider;
-use crate::anthropic_client::AnthropicClient;
-use crate::example::Example;
-use crate::format_prompt::extract_cursor_excerpt_from_example;
-use crate::openai_client::OpenAiClient;
-use crate::paths::LLM_CACHE_DB;
-use crate::word_diff::unified_to_word_diff;
-use anyhow::Result;
+use crate::{
+ BatchProvider,
+ anthropic_client::AnthropicClient,
+ example::Example,
+ format_prompt::extract_cursor_excerpt_from_example,
+ openai_client::OpenAiClient,
+ parse_output::run_parse_output,
+ paths::LLM_CACHE_DB,
+ progress::{ExampleProgress, Step},
+ word_diff::unified_to_word_diff,
+};
+use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
-use std::io::{BufWriter, Write};
-use std::path::PathBuf;
+use std::sync::OnceLock;
/// Arguments for the QA command.
#[derive(Debug, Clone, clap::Args)]
@@ -22,10 +25,6 @@ pub struct QaArgs {
#[clap(long)]
pub no_batch: bool,
- /// Wait for batch to complete (polls every 30s)
- #[clap(long)]
- pub wait: bool,
-
/// Which LLM provider to use (anthropic or openai)
#[clap(long, default_value = "openai")]
pub backend: BatchProvider,
@@ -63,15 +62,24 @@ pub struct QaResult {
}
/// Build the assessment prompt for an example.
-pub fn build_prompt(example: &Example) -> Option<String> {
- let prediction = example.predictions.first()?;
- let actual_patch = prediction.actual_patch.as_ref()?;
- let prompt_inputs = example.prompt_inputs.as_ref()?;
+pub fn build_prompt(example: &Example) -> Result<String> {
+ let prediction = example
+ .predictions
+ .first()
+ .context("no predictions available")?;
+ let actual_patch = prediction
+ .actual_patch
+ .as_ref()
+ .context("no actual_patch available (run predict first)")?;
+ let prompt_inputs = example
+ .prompt_inputs
+ .as_ref()
+ .context("prompt_inputs missing (run context retrieval first)")?;
let actual_patch_word_diff = unified_to_word_diff(actual_patch);
- // Format cursor excerpt (reuse from format_prompt)
- let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
+ let cursor_excerpt =
+ extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?;
let mut edit_history = String::new();
for event in &prompt_inputs.edit_history {
@@ -93,15 +101,12 @@ pub fn build_prompt(example: &Example) -> Option<String> {
}
let prompt_template = crate::prompt_assets::get_prompt("qa.md");
- Some(
- prompt_template
- .replace("{edit_history}", &edit_history)
- .replace("{cursor_excerpt}", &cursor_excerpt)
- .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
- )
+ Ok(prompt_template
+ .replace("{edit_history}", &edit_history)
+ .replace("{cursor_excerpt}", &cursor_excerpt)
+ .replace("{actual_patch_word_diff}", &actual_patch_word_diff))
}
-/// Extract a code block from a response.
fn extract_codeblock(response: &str) -> Option<String> {
let lines: Vec<&str> = response.lines().collect();
for (i, line) in lines.iter().enumerate() {
@@ -118,11 +123,9 @@ fn extract_codeblock(response: &str) -> Option<String> {
None
}
-/// Parse the LLM response into a QaResult.
fn parse_response(response_text: &str) -> QaResult {
let codeblock = extract_codeblock(response_text);
- // Try parsing codeblock first, then fall back to raw response
for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
let Some(text) = text_to_parse else {
continue;
@@ -145,7 +148,6 @@ fn parse_response(response_text: &str) -> QaResult {
}
}
- // If all parsing attempts fail, return error
QaResult {
reasoning: Some(response_text.to_string()),
reverts_edits: None,
@@ -155,239 +157,148 @@ fn parse_response(response_text: &str) -> QaResult {
}
}
-enum QaClient {
- Anthropic(AnthropicClient),
- OpenAi(OpenAiClient),
-}
+static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
+static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
+static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
+static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
-impl QaClient {
- async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
- match self {
- QaClient::Anthropic(client) => {
- let messages = vec![anthropic::Message {
- role: anthropic::Role::User,
- content: vec![anthropic::RequestContent::Text {
- text: prompt.to_string(),
- cache_control: None,
- }],
- }];
- let response = client
- .generate(model, max_tokens, messages, None, false)
- .await?;
- Ok(response.map(|r| {
- r.content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => Some(text.as_str()),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("")
- }))
- }
- QaClient::OpenAi(client) => {
- let messages = vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt.to_string()),
- }];
- let response = client
- .generate(model, max_tokens, messages, None, false)
- .await?;
- Ok(response.map(|r| {
- r.choices
- .into_iter()
- .filter_map(|choice| match choice.message {
- open_ai::RequestMessage::Assistant { content, .. } => {
- content.map(|c| match c {
- open_ai::MessageContent::Plain(text) => text,
- open_ai::MessageContent::Multipart(parts) => parts
- .into_iter()
- .filter_map(|p| match p {
- open_ai::MessagePart::Text { text } => Some(text),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join(""),
- })
- }
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("")
- }))
- }
- }
- }
-
- async fn sync_batches(&self) -> Result<()> {
- match self {
- QaClient::Anthropic(client) => client.sync_batches().await,
- QaClient::OpenAi(client) => client.sync_batches().await,
- }
- }
-}
-
-/// Run the QA evaluation on a set of examples.
+/// Run QA evaluation for a single example.
pub async fn run_qa(
- examples: &mut [Example],
+ example: &mut Example,
args: &QaArgs,
- output_path: Option<&PathBuf>,
+ example_progress: &ExampleProgress,
) -> Result<()> {
- let model = model_for_backend(args.backend);
- let client = match args.backend {
- BatchProvider::Anthropic => {
- if args.no_batch {
- QaClient::Anthropic(AnthropicClient::plain()?)
- } else {
- QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
- }
- }
- BatchProvider::Openai => {
- if args.no_batch {
- QaClient::OpenAi(OpenAiClient::plain()?)
- } else {
- QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
- }
- }
- };
-
- eprintln!(
- "Using model: {}, backend: {:?}, batching: {}",
- model, args.backend, !args.no_batch
- );
-
- // First pass: send requests (client handles caching internally)
- let mut prompts: Vec<(usize, String)> = Vec::new();
- let mut skipped_count = 0;
-
- for (idx, example) in examples.iter().enumerate() {
- let Some(prompt) = build_prompt(example) else {
- skipped_count += 1;
- continue;
- };
- prompts.push((idx, prompt));
+ if example
+ .qa
+ .first()
+ .and_then(|q| q.as_ref())
+ .and_then(|q| q.confidence)
+ .is_some()
+ {
+ return Ok(());
}
- if skipped_count > 0 {
- eprintln!("Skipping {} items with missing actual_patch", skipped_count);
+ run_parse_output(example).context("Failed to execute run_parse_output")?;
+
+ if example.prompt_inputs.is_none() {
+ anyhow::bail!("prompt_inputs missing (run context retrieval first)");
}
- eprintln!("{} items to process", prompts.len());
+ let step_progress = example_progress.start(Step::Qa);
- // Process all items
- let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();
+ let model = model_for_backend(args.backend);
+ let prompt = build_prompt(example).context("Failed to build QA prompt")?;
- if args.no_batch {
- // Synchronous processing
- for (i, (idx, prompt)) in prompts.iter().enumerate() {
- eprint!("\rProcessing {}/{}", i + 1, prompts.len());
+ step_progress.set_substatus("generating");
- let response = client.generate(model, 1024, prompt).await?;
- let result = response.map(|text| parse_response(&text));
- results.push((*idx, result));
- }
- eprintln!();
- } else {
- // Queue all for batching
- for (idx, prompt) in &prompts {
- let response = client.generate(model, 1024, prompt).await?;
- let result = response.map(|text| parse_response(&text));
- results.push((*idx, result));
- }
+ let response = match args.backend {
+ BatchProvider::Anthropic => {
+ let client = if args.no_batch {
+ ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
+ AnthropicClient::plain().expect("Failed to create Anthropic client")
+ })
+ } else {
+ ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
+ AnthropicClient::batch(&LLM_CACHE_DB)
+ .expect("Failed to create Anthropic client")
+ })
+ };
- // Sync batches (upload pending, download finished)
- client.sync_batches().await?;
-
- if args.wait {
- eprintln!("Waiting for batch to complete...");
- loop {
- std::thread::sleep(std::time::Duration::from_secs(30));
- client.sync_batches().await?;
-
- // Re-check all items that didn't have results
- let mut all_done = true;
- for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
- if results[result_idx].1.is_none() {
- let response = client.generate(model, 1024, prompt).await?;
- if let Some(text) = response {
- results[result_idx] = (*idx, Some(parse_response(&text)));
- } else {
- all_done = false;
- }
- }
- }
+ let messages = vec![anthropic::Message {
+ role: anthropic::Role::User,
+ content: vec![anthropic::RequestContent::Text {
+ text: prompt,
+ cache_control: None,
+ }],
+ }];
- let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
- if all_done {
- break;
- }
- eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
- }
- } else {
- let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
- if pending_count > 0 {
- eprintln!(
- "Batch submitted. {} pending. Run again later to retrieve results.",
- pending_count
- );
- }
- }
- }
+ let Some(response) = client.generate(model, 1024, messages, None, false).await? else {
+ return Ok(());
+ };
- // Build results map by index
- let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
- std::collections::HashMap::new();
- for (idx, result) in results {
- if let Some(r) = result {
- results_by_idx.insert(idx, r);
+ response
+ .content
+ .iter()
+ .filter_map(|c| match c {
+ anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
}
- }
+ BatchProvider::Openai => {
+ let client = if args.no_batch {
+ OPENAI_CLIENT_PLAIN
+ .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
+ } else {
+ OPENAI_CLIENT_BATCH.get_or_init(|| {
+ OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
+ })
+ };
- // Output results
- let mut writer: Box<dyn Write> = if let Some(path) = output_path {
- Box::new(BufWriter::new(std::fs::File::create(path)?))
- } else {
- Box::new(std::io::stdout())
- };
+ let messages = vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }];
- let mut num_total = 0;
- let mut num_reverts_edits = 0;
+ let Some(response) = client.generate(model, 1024, messages, None, false).await? else {
+ return Ok(());
+ };
- for (idx, example) in examples.iter_mut().enumerate() {
- // Skip examples that couldn't be processed
- if build_prompt(example).is_none() {
- continue;
+ response
+ .choices
+ .into_iter()
+ .filter_map(|choice| match choice.message {
+ open_ai::RequestMessage::Assistant { content, .. } => {
+ content.map(|c| match c {
+ open_ai::MessageContent::Plain(text) => text,
+ open_ai::MessageContent::Multipart(parts) => parts
+ .into_iter()
+ .filter_map(|p| match p {
+ open_ai::MessagePart::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join(""),
+ })
+ }
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
}
+ };
- let result = results_by_idx.get(&idx).cloned();
-
- if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
- num_reverts_edits += 1;
- }
- num_total += 1;
+ let result = parse_response(&response);
- // Populate QA results for each prediction (currently only first prediction is evaluated)
- example.qa = example
- .predictions
- .iter()
- .enumerate()
- .map(|(i, _)| if i == 0 { result.clone() } else { None })
- .collect();
+ example.qa = example
+ .predictions
+ .iter()
+ .enumerate()
+ .map(|(i, _)| if i == 0 { Some(result.clone()) } else { None })
+ .collect();
- writeln!(writer, "{}", serde_json::to_string(&example)?)?;
- }
+ Ok(())
+}
- if let Some(path) = output_path {
- eprintln!("Results written to {}", path.display());
+/// Sync batches for QA (upload pending requests, download finished results).
+pub async fn sync_batches(args: &QaArgs) -> Result<()> {
+ if args.no_batch {
+ return Ok(());
}
- eprintln!("Processed: {} items", num_total);
- if num_total > 0 {
- eprintln!(
- "Reverts edits: {} ({:.2}%)",
- num_reverts_edits,
- num_reverts_edits as f64 / num_total as f64 * 100.0
- );
+ match args.backend {
+ BatchProvider::Anthropic => {
+ let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
+ AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
+ });
+ client.sync_batches().await?;
+ }
+ BatchProvider::Openai => {
+ let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
+ OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
+ });
+ client.sync_batches().await?;
+ }
}
-
Ok(())
}
@@ -4,17 +4,19 @@
//! predictions that need improvement (based on reverts_edits or low confidence),
//! and uses an LLM to generate improved predictions.
-use crate::BatchProvider;
-use crate::PredictionProvider;
-use crate::anthropic_client::AnthropicClient;
-use crate::example::{Example, ExamplePrediction};
-use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
-use crate::openai_client::OpenAiClient;
-use crate::paths::LLM_CACHE_DB;
-use crate::word_diff::unified_to_word_diff;
-use anyhow::Result;
-use std::io::{BufWriter, Write};
-use std::path::PathBuf;
+use crate::{
+ BatchProvider, PredictionProvider,
+ anthropic_client::AnthropicClient,
+ example::{Example, ExamplePrediction},
+ format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example},
+ openai_client::OpenAiClient,
+ parse_output::run_parse_output,
+ paths::LLM_CACHE_DB,
+ progress::{ExampleProgress, Step},
+ word_diff::unified_to_word_diff,
+};
+use anyhow::{Context as _, Result};
+use std::sync::OnceLock;
/// Arguments for the repair command.
#[derive(Debug, Clone, clap::Args)]
@@ -23,10 +25,6 @@ pub struct RepairArgs {
#[clap(long)]
pub no_batch: bool,
- /// Wait for batch to complete (polls every 30s)
- #[clap(long)]
- pub wait: bool,
-
/// Confidence threshold: repair predictions with confidence <= this value (1-5)
#[clap(long, default_value = "2")]
pub confidence_threshold: u8,
@@ -44,17 +42,28 @@ fn model_for_backend(backend: BatchProvider) -> &'static str {
}
/// Build the repair prompt for an example that needs improvement.
-///
-/// Returns None if the example doesn't have the required data (predictions, qa, prompt_inputs).
-pub fn build_repair_prompt(example: &Example) -> Option<String> {
- let prediction = example.predictions.first()?;
- let qa = example.qa.first()?.as_ref()?;
- let prompt_inputs = example.prompt_inputs.as_ref()?;
- let actual_patch = prediction.actual_patch.as_ref()?;
+pub fn build_repair_prompt(example: &Example) -> Result<String> {
+ let prediction = example
+ .predictions
+ .first()
+ .context("no predictions available")?;
+ let qa = example
+ .qa
+ .first()
+ .context("no QA results available")?
+ .as_ref()
+ .context("QA result is None")?;
+ let prompt_inputs = example
+ .prompt_inputs
+ .as_ref()
+ .context("prompt_inputs missing (run context retrieval first)")?;
+ let actual_patch = prediction
+ .actual_patch
+ .as_ref()
+ .context("no actual_patch available (run predict first)")?;
let actual_patch_word_diff = unified_to_word_diff(actual_patch);
- // Format edit history similar to qa.rs
let mut edit_history = String::new();
for event in &prompt_inputs.edit_history {
match event.as_ref() {
@@ -74,13 +83,11 @@ pub fn build_repair_prompt(example: &Example) -> Option<String> {
}
}
- // Format related files context (reuse from TeacherPrompt)
let context = TeacherPrompt::format_context(example);
- // Format cursor excerpt with editable region markers (reuse from format_prompt)
- let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
+ let cursor_excerpt =
+ extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?;
- // Get QA feedback
let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
let reverts_edits = qa
.reverts_edits
@@ -90,16 +97,14 @@ pub fn build_repair_prompt(example: &Example) -> Option<String> {
.map_or("unknown".to_string(), |v| v.to_string());
let prompt_template = crate::prompt_assets::get_prompt("repair.md");
- Some(
- prompt_template
- .replace("{edit_history}", &edit_history)
- .replace("{context}", &context)
- .replace("{cursor_excerpt}", &cursor_excerpt)
- .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
- .replace("{reverts_edits}", reverts_edits)
- .replace("{confidence}", &confidence)
- .replace("{qa_reasoning}", qa_reasoning),
- )
+ Ok(prompt_template
+ .replace("{edit_history}", &edit_history)
+ .replace("{context}", &context)
+ .replace("{cursor_excerpt}", &cursor_excerpt)
+ .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
+ .replace("{reverts_edits}", reverts_edits)
+ .replace("{confidence}", &confidence)
+ .replace("{qa_reasoning}", qa_reasoning))
}
/// Check if an example needs repair based on QA feedback.
@@ -108,12 +113,10 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
return false;
};
- // Repair if reverts_edits is true
if qa.reverts_edits == Some(true) {
return true;
}
- // Repair if confidence is at or below threshold
if let Some(confidence) = qa.confidence {
if confidence <= confidence_threshold {
return true;
@@ -123,264 +126,170 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
false
}
-/// Parse the repair response into a prediction.
-fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
- let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, response_text)?;
-
- Ok(ExamplePrediction {
- actual_patch: Some(actual_patch),
- actual_output: response_text.to_string(),
- actual_cursor_offset,
- error: None,
- provider: PredictionProvider::Repair,
- })
+/// Check if an example already has a successful repair prediction.
+fn has_successful_repair(example: &Example) -> bool {
+ example
+ .predictions
+ .iter()
+ .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
}
-enum RepairClient {
- Anthropic(AnthropicClient),
- OpenAi(OpenAiClient),
-}
+static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
+static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
+static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
+static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
-impl RepairClient {
- async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
- match self {
- RepairClient::Anthropic(client) => {
- let messages = vec![anthropic::Message {
- role: anthropic::Role::User,
- content: vec![anthropic::RequestContent::Text {
- text: prompt.to_string(),
- cache_control: None,
- }],
- }];
- let response = client
- .generate(model, max_tokens, messages, None, false)
- .await?;
- Ok(response.map(|r| {
- r.content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => Some(text.as_str()),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("")
- }))
- }
- RepairClient::OpenAi(client) => {
- let messages = vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt.to_string()),
- }];
- let response = client
- .generate(model, max_tokens, messages, None, false)
- .await?;
- Ok(response.map(|r| {
- r.choices
- .into_iter()
- .filter_map(|choice| match choice.message {
- open_ai::RequestMessage::Assistant { content, .. } => {
- content.map(|c| match c {
- open_ai::MessageContent::Plain(text) => text,
- open_ai::MessageContent::Multipart(parts) => parts
- .into_iter()
- .filter_map(|p| match p {
- open_ai::MessagePart::Text { text } => Some(text),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join(""),
- })
- }
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("")
- }))
- }
- }
+/// Run repair for a single example.
+pub async fn run_repair(
+ example: &mut Example,
+ args: &RepairArgs,
+ example_progress: &ExampleProgress,
+) -> Result<()> {
+ if has_successful_repair(example) {
+ return Ok(());
}
- async fn sync_batches(&self) -> Result<()> {
- match self {
- RepairClient::Anthropic(client) => client.sync_batches().await,
- RepairClient::OpenAi(client) => client.sync_batches().await,
- }
+ if !needs_repair(example, args.confidence_threshold) {
+ return Ok(());
}
-}
-/// Run the repair process on a set of examples.
-pub async fn run_repair(
- examples: &mut [Example],
- args: &RepairArgs,
- output_path: Option<&PathBuf>,
-) -> Result<()> {
+ run_parse_output(example).context("Failed to execute run_parse_output")?;
+
+ if example.prompt_inputs.is_none() {
+ anyhow::bail!("prompt_inputs missing (run context retrieval first)");
+ }
+
+ if example.predictions.is_empty() {
+ anyhow::bail!("no predictions available (run predict first)");
+ }
+
+ if example.qa.is_empty() {
+ anyhow::bail!("no QA results available (run qa first)");
+ }
+
+ let step_progress = example_progress.start(Step::Repair);
+
let model = model_for_backend(args.backend);
- let client = match args.backend {
+ let prompt = build_repair_prompt(example).context("Failed to build repair prompt")?;
+
+ step_progress.set_substatus("generating");
+
+ let response = match args.backend {
BatchProvider::Anthropic => {
- if args.no_batch {
- RepairClient::Anthropic(AnthropicClient::plain()?)
+ let client = if args.no_batch {
+ ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
+ AnthropicClient::plain().expect("Failed to create Anthropic client")
+ })
} else {
- RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
- }
+ ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
+ AnthropicClient::batch(&LLM_CACHE_DB)
+ .expect("Failed to create Anthropic client")
+ })
+ };
+
+ let messages = vec![anthropic::Message {
+ role: anthropic::Role::User,
+ content: vec![anthropic::RequestContent::Text {
+ text: prompt,
+ cache_control: None,
+ }],
+ }];
+
+ let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
+ return Ok(());
+ };
+
+ response
+ .content
+ .iter()
+ .filter_map(|c| match c {
+ anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
}
BatchProvider::Openai => {
- if args.no_batch {
- RepairClient::OpenAi(OpenAiClient::plain()?)
+ let client = if args.no_batch {
+ OPENAI_CLIENT_PLAIN
+ .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
} else {
- RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
- }
+ OPENAI_CLIENT_BATCH.get_or_init(|| {
+ OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
+ })
+ };
+
+ let messages = vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }];
+
+ let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
+ return Ok(());
+ };
+
+ response
+ .choices
+ .into_iter()
+ .filter_map(|choice| match choice.message {
+ open_ai::RequestMessage::Assistant { content, .. } => {
+ content.map(|c| match c {
+ open_ai::MessageContent::Plain(text) => text,
+ open_ai::MessageContent::Multipart(parts) => parts
+ .into_iter()
+ .filter_map(|p| match p {
+ open_ai::MessagePart::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join(""),
+ })
+ }
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
}
};
- eprintln!(
- "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}",
- model, args.backend, !args.no_batch, args.confidence_threshold
- );
-
- // First pass: identify examples that need repair and build prompts
- let mut repair_items: Vec<(usize, String)> = Vec::new();
- let mut skipped_missing_data = 0;
- let mut skipped_no_repair_needed = 0;
-
- for (idx, example) in examples.iter().enumerate() {
- // Skip if missing predictions or qa
- if example.predictions.is_empty() || example.qa.is_empty() {
- skipped_missing_data += 1;
- continue;
- }
-
- // Skip if doesn't need repair
- if !needs_repair(example, args.confidence_threshold) {
- skipped_no_repair_needed += 1;
- continue;
- }
-
- // Build repair prompt
- let Some(prompt) = build_repair_prompt(example) else {
- skipped_missing_data += 1;
- continue;
- };
+ let parse_result = TeacherPrompt::parse(example, &response);
+ let err = parse_result
+ .as_ref()
+ .err()
+ .map(|e| format!("Failed to parse repair response: {}", e));
- repair_items.push((idx, prompt));
- }
+ let (actual_patch, actual_cursor_offset) = parse_result.ok().unzip();
- eprintln!(
- "Skipping {} items with missing data, {} items that don't need repair",
- skipped_missing_data, skipped_no_repair_needed
- );
- eprintln!("{} items to repair", repair_items.len());
+ example.predictions.push(ExamplePrediction {
+ actual_patch,
+ actual_output: response,
+ actual_cursor_offset: actual_cursor_offset.flatten(),
+ error: err,
+ provider: PredictionProvider::Repair,
+ });
- // Process all items
- let mut results: Vec<(usize, Option<String>)> = Vec::new();
+ Ok(())
+}
+/// Sync batches for repair (upload pending requests, download finished results).
+pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
if args.no_batch {
- // Synchronous processing
- for (i, (idx, prompt)) in repair_items.iter().enumerate() {
- eprint!("\rProcessing {}/{}", i + 1, repair_items.len());
-
- let response = client.generate(model, 16384, prompt).await?;
- results.push((*idx, response));
- }
- eprintln!();
- } else {
- // Queue all for batching
- for (idx, prompt) in &repair_items {
- let response = client.generate(model, 16384, prompt).await?;
- results.push((*idx, response));
- }
-
- // Sync batches (upload pending, download finished)
- client.sync_batches().await?;
-
- if args.wait {
- eprintln!("Waiting for batch to complete...");
- loop {
- std::thread::sleep(std::time::Duration::from_secs(30));
- client.sync_batches().await?;
-
- // Re-check all items that didn't have results
- let mut all_done = true;
- for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
- if results[result_idx].1.is_none() {
- let response = client.generate(model, 16384, prompt).await?;
- if let Some(text) = response {
- results[result_idx] = (*idx, Some(text));
- } else {
- all_done = false;
- }
- }
- }
-
- let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
- if all_done {
- break;
- }
- eprintln!(
- "Still waiting... {}/{} results",
- done_count,
- repair_items.len()
- );
- }
- } else {
- let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
- if pending_count > 0 {
- eprintln!(
- "Batch submitted. {} pending. Run again later to retrieve results.",
- pending_count
- );
- }
- }
+ return Ok(());
}
- // Build results map by index
- let mut results_by_idx: std::collections::HashMap<usize, String> =
- std::collections::HashMap::new();
- for (idx, result) in results {
- if let Some(r) = result {
- results_by_idx.insert(idx, r);
+ match args.backend {
+ BatchProvider::Anthropic => {
+ let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
+ AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
+ });
+ client.sync_batches().await?;
}
- }
-
- // Output results
- let mut writer: Box<dyn Write> = if let Some(path) = output_path {
- Box::new(BufWriter::new(std::fs::File::create(path)?))
- } else {
- Box::new(std::io::stdout())
- };
-
- let mut num_repaired = 0;
- let mut num_repair_errors = 0;
-
- for (idx, example) in examples.iter_mut().enumerate() {
- // Add repair prediction if we have a result
- if let Some(response_text) = results_by_idx.get(&idx) {
- match parse_repair_response(example, response_text) {
- Ok(prediction) => {
- example.predictions.push(prediction);
- num_repaired += 1;
- }
- Err(e) => {
- // Add error prediction
- example.predictions.push(ExamplePrediction {
- actual_patch: None,
- actual_output: response_text.clone(),
- actual_cursor_offset: None,
- error: Some(format!("Failed to parse repair response: {}", e)),
- provider: PredictionProvider::Repair,
- });
- num_repair_errors += 1;
- }
- }
+ BatchProvider::Openai => {
+ let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
+ OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
+ });
+ client.sync_batches().await?;
}
-
- writeln!(writer, "{}", serde_json::to_string(&example)?)?;
}
- if let Some(path) = output_path {
- eprintln!("Results written to {}", path.display());
- }
-
- eprintln!("Repaired: {} items", num_repaired);
- eprintln!("Repair errors: {} items", num_repair_errors);
-
Ok(())
}