From 84b40d507f099b2f4940dacdf7036f0538027578 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Sat, 24 Jan 2026 01:28:12 +0200 Subject: [PATCH] ep: Add `qa` subcommand to check predictions quality (#47520) Release Notes: - N/A --- crates/edit_prediction_cli/src/main.rs | 36 +- crates/edit_prediction_cli/src/qa.rs | 395 ++++++++++++++++++++ crates/edit_prediction_cli/src/word_diff.rs | 343 +++++++++++++++++ 3 files changed, 773 insertions(+), 1 deletion(-) create mode 100644 crates/edit_prediction_cli/src/qa.rs create mode 100644 crates/edit_prediction_cli/src/word_diff.rs diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 8329e394bfa8fe47b9a40f327e7cb38b9839da6c..315e954c5f52d32c77e1f71ee4af2950ace6f83d 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -12,12 +12,14 @@ mod paths; mod predict; mod progress; mod pull_examples; +mod qa; mod reorder_patch; mod retrieve_context; mod score; mod split_commit; mod split_dataset; mod synthesize; +mod word_diff; use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; use collections::HashSet; use edit_prediction::EditPredictionStore; @@ -159,6 +161,8 @@ enum Command { FilterLanguages(FilterLanguagesArgs), /// Import Anthropic batch results by batch IDs (useful for recovering after database loss) ImportBatch(ImportBatchArgs), + /// Assess the quality of predictions using LLM-as-a-judge + Qa(qa::QaArgs), } impl Display for Command { @@ -194,6 +198,9 @@ impl Display for Command { Command::ImportBatch(args) => { write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" ")) } + Command::Qa(_) => { + write!(f, "qa") + } } } } @@ -558,6 +565,32 @@ fn main() { } return; } + Command::Qa(qa_args) => { + // Read examples from input files + let mut examples = example::read_example_files(&args.inputs); + + // Apply filters + if let Some(name_filter) = &args.name { + examples.retain(|e| e.spec.name.contains(name_filter)); + 
} + if let Some(repo_filter) = &args.repo { + examples.retain(|e| e.spec.repository_url.contains(repo_filter)); + } + if let Some(offset) = args.offset { + examples.splice(0..offset, []); + } + if let Some(limit) = args.limit { + examples.truncate(limit); + } + + smol::block_on(async { + if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await { + eprintln!("Error: {:?}", e); + std::process::exit(1); + } + }); + return; + } _ => {} } @@ -724,7 +757,8 @@ fn main() { | Command::SplitCommit(_) | Command::Split(_) | Command::FilterLanguages(_) - | Command::ImportBatch(_) => { + | Command::ImportBatch(_) + | Command::Qa(_) => { unreachable!() } } diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs new file mode 100644 index 0000000000000000000000000000000000000000..6c30c6f49b85ae5bb7521db68bf99d932e413a3a --- /dev/null +++ b/crates/edit_prediction_cli/src/qa.rs @@ -0,0 +1,395 @@ +//! Quality assessment of predictions using LLM-as-a-judge. +//! +//! This module uses the Anthropic Batch API to evaluate prediction quality. +//! Caching is handled by the underlying AnthropicClient. + +use crate::anthropic_client::AnthropicClient; +use crate::example::Example; +use crate::paths::CACHE_DIR; +use crate::word_diff::unified_to_word_diff; +use anthropic::{Message, RequestContent, Role}; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::io::{BufWriter, Write}; +use std::path::PathBuf; +use std::sync::LazyLock; + +/// Model to use for QA evaluation. +const MODEL: &str = "claude-sonnet-4-5"; + +/// Path to the QA cache database. +pub static QA_CACHE_DB: LazyLock = LazyLock::new(|| CACHE_DIR.join("qa_cache.sqlite")); + +/// Arguments for the QA command. 
+#[derive(Debug, Clone, clap::Args)] +pub struct QaArgs { + /// Use synchronous API instead of batch + #[clap(long)] + pub no_batch: bool, + + /// Wait for batch to complete (polls every 30s) + #[clap(long)] + pub wait: bool, +} + +/// Result of QA evaluation for a single prediction. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QaResult { + /// Free-form reasoning from the judge. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + + /// Does the prediction undo/revert changes the user intentionally made? + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reverts_edits: Option, + + /// Confidence score (1-5) for user acceptance likelihood. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub confidence: Option, + + /// The raw response from the model. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub response: Option, + + /// Error message if parsing or request failed. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// Build the assessment prompt for an example. +pub fn build_prompt(example: &Example) -> Option { + let prediction = example.predictions.first()?; + let actual_patch = prediction.actual_patch.as_ref()?; + let prompt_inputs = example.prompt_inputs.as_ref()?; + + let actual_patch_word_diff = unified_to_word_diff(actual_patch); + + let mut edit_history = String::new(); + for event in &prompt_inputs.edit_history { + match event.as_ref() { + zeta_prompt::Event::BufferChange { + path, + old_path, + diff, + predicted: _, + in_open_source_repo: _, + } => { + edit_history.push_str(&format!("--- a{}\n", old_path.display())); + edit_history.push_str(&format!("+++ b{}\n", path.display())); + let diff_word_diff = unified_to_word_diff(diff); + edit_history.push_str(&diff_word_diff); + edit_history.push_str("\n\n"); + } + } + } + + Some(format!( + r#" +You are evaluating an edit prediction model for a code editor. 
The model observes a programmer's recent edit history and predicts what edit they will make next. + +All diffs are in the word-diff format. + +The model is instructed to: +- Complete partially-applied refactoring or changes +- Maintain consistency with established patterns and style +- NOT delete or revert text that was just added (unless the user explicitly undid it themselves) + +## Edit History (chronological) +``````` +{edit_history} +``````` + +## Predicted Next Edit +``````` +{actual_patch_word_diff} +``````` + +## Evaluate + +1. **reverts_edits**: Does the prediction undo, or revert changes the user intentionally made in the **edit history**? + +2. **confidence**: How likely is the user to accept this suggestion? + - 1 = Definitely reject (wrong, nonsensical, or harmful) + - 2 = Probably reject (doesn't fit intent or pattern) + - 3 = Uncertain (plausible but not clearly correct) + - 4 = Probably accept (reasonable next step) + - 5 = Definitely accept (obvious continuation) + +Output JSON in this format: + +``` +{{ + "reasoning": "your reasoning here", + "reverts_edits": true/false, + "confidence": 1-5 +}} +``` +"# + )) +} + +/// Extract a code block from a response. +fn extract_codeblock(response: &str) -> Option { + let lines: Vec<&str> = response.lines().collect(); + for (i, line) in lines.iter().enumerate() { + if line.starts_with("```") { + let start = i + 1; + for (j, end_line) in lines[start..].iter().enumerate() { + if end_line.starts_with("```") { + return Some(lines[start..start + j].join("\n")); + } + } + return Some(lines[start..].join("\n")); + } + } + None +} + +/// Parse the LLM response into a QaResult. 
+fn parse_response(response_text: &str) -> QaResult { + let codeblock = extract_codeblock(response_text); + + // Try parsing codeblock first, then fall back to raw response + for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] { + let Some(text) = text_to_parse else { + continue; + }; + + if let Ok(parsed) = serde_json::from_str::(text) { + return QaResult { + reasoning: parsed + .get("reasoning") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), + reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()), + confidence: parsed + .get("confidence") + .and_then(|v| v.as_u64()) + .map(|v| v as u8), + response: Some(response_text.to_string()), + error: None, + }; + } + } + + // If all parsing attempts fail, return error + QaResult { + reasoning: Some(response_text.to_string()), + reverts_edits: None, + confidence: None, + response: Some(response_text.to_string()), + error: Some("Could not parse JSON from response".to_string()), + } +} + +/// Run the QA evaluation on a set of examples. +pub async fn run_qa( + examples: &mut [Example], + args: &QaArgs, + output_path: Option<&PathBuf>, +) -> Result<()> { + let client = if args.no_batch { + AnthropicClient::plain()? + } else { + AnthropicClient::batch(&QA_CACHE_DB)? 
+ }; + + eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch); + + // First pass: send requests (client handles caching internally) + let mut prompts: Vec<(usize, String)> = Vec::new(); + let mut skipped_count = 0; + + for (idx, example) in examples.iter().enumerate() { + let Some(prompt) = build_prompt(example) else { + skipped_count += 1; + continue; + }; + prompts.push((idx, prompt)); + } + + if skipped_count > 0 { + eprintln!("Skipping {} items with missing actual_patch", skipped_count); + } + + eprintln!("{} items to process", prompts.len()); + + // Process all items + let mut results: Vec<(usize, Option)> = Vec::new(); + + if args.no_batch { + // Synchronous processing + for (i, (idx, prompt)) in prompts.iter().enumerate() { + eprint!("\rProcessing {}/{}", i + 1, prompts.len()); + + let messages = vec![Message { + role: Role::User, + content: vec![RequestContent::Text { + text: prompt.clone(), + cache_control: None, + }], + }]; + + let response = client.generate(MODEL, 1024, messages).await?; + let result = response.map(|r| { + let text = r + .content + .iter() + .filter_map(|c| match c { + anthropic::ResponseContent::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join(""); + parse_response(&text) + }); + results.push((*idx, result)); + } + eprintln!(); + } else { + // Queue all for batching + for (idx, prompt) in &prompts { + let messages = vec![Message { + role: Role::User, + content: vec![RequestContent::Text { + text: prompt.clone(), + cache_control: None, + }], + }]; + + let response = client.generate(MODEL, 1024, messages).await?; + let result = response.map(|r| { + let text = r + .content + .iter() + .filter_map(|c| match c { + anthropic::ResponseContent::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join(""); + parse_response(&text) + }); + results.push((*idx, result)); + } + + // Sync batches (upload pending, download finished) + client.sync_batches().await?; + + if args.wait { + 
eprintln!("Waiting for batch to complete..."); + loop { + std::thread::sleep(std::time::Duration::from_secs(30)); + client.sync_batches().await?; + + // Re-check all items that didn't have results + let mut all_done = true; + for (result_idx, (idx, prompt)) in prompts.iter().enumerate() { + if results[result_idx].1.is_none() { + let messages = vec![Message { + role: Role::User, + content: vec![RequestContent::Text { + text: prompt.clone(), + cache_control: None, + }], + }]; + + let response = client.generate(MODEL, 1024, messages).await?; + if let Some(r) = response { + let text = r + .content + .iter() + .filter_map(|c| match c { + anthropic::ResponseContent::Text { text } => { + Some(text.as_str()) + } + _ => None, + }) + .collect::>() + .join(""); + results[result_idx] = (*idx, Some(parse_response(&text))); + } else { + all_done = false; + } + } + } + + let done_count = results.iter().filter(|(_, r)| r.is_some()).count(); + if all_done { + break; + } + eprintln!("Still waiting... {}/{} results", done_count, prompts.len()); + } + } else { + let pending_count = results.iter().filter(|(_, r)| r.is_none()).count(); + if pending_count > 0 { + eprintln!( + "Batch submitted. {} pending. 
Run again later to retrieve results.", + pending_count + ); + } + } + } + + // Build results map by index + let mut results_by_idx: std::collections::HashMap = + std::collections::HashMap::new(); + for (idx, result) in results { + if let Some(r) = result { + results_by_idx.insert(idx, r); + } + } + + // Output results + let mut writer: Box = if let Some(path) = output_path { + Box::new(BufWriter::new(std::fs::File::create(path)?)) + } else { + Box::new(std::io::stdout()) + }; + + let mut num_total = 0; + let mut num_reverts_edits = 0; + + for (idx, example) in examples.iter_mut().enumerate() { + // Skip examples that couldn't be processed + if build_prompt(example).is_none() { + continue; + } + + let result = results_by_idx + .get(&idx) + .cloned() + .unwrap_or_else(|| QaResult { + reasoning: None, + reverts_edits: None, + confidence: None, + response: None, + error: Some("Result not found".to_string()), + }); + + if result.reverts_edits == Some(true) { + num_reverts_edits += 1; + } + num_total += 1; + + // Add QA result to example and output + let mut example_json = serde_json::to_value(&example)?; + example_json["qa"] = serde_json::to_value(&result)?; + writeln!(writer, "{}", serde_json::to_string(&example_json)?)?; + } + + if let Some(path) = output_path { + eprintln!("Results written to {}", path.display()); + } + + eprintln!("Processed: {} items", num_total); + if num_total > 0 { + eprintln!( + "Reverts edits: {} ({:.2}%)", + num_reverts_edits, + num_reverts_edits as f64 / num_total as f64 * 100.0 + ); + } + + Ok(()) +} diff --git a/crates/edit_prediction_cli/src/word_diff.rs b/crates/edit_prediction_cli/src/word_diff.rs new file mode 100644 index 0000000000000000000000000000000000000000..b5db40d52e5b15e57cf3f6b92f4a1c0a0bbc13da --- /dev/null +++ b/crates/edit_prediction_cli/src/word_diff.rs @@ -0,0 +1,343 @@ +//! Word-diff utilities for converting unified diffs to word-diff format. + +/// Convert unified diff to word-diff format. 
///
/// This transforms line-based diffs into word-level diffs where:
/// - Deletions are marked with `[-...-]`
/// - Insertions are marked with `{+...+}`
///
/// Each run of consecutive removed/added lines is diffed as one unit. File
/// headers, hunk headers, context lines and any other lines are copied through
/// unchanged.
///
/// The line counts in a `@@ -a,b +c,d @@` header are used to tell hunk content
/// from file headers, so a removed line whose text starts with `--` (which
/// appears as `---…` in the diff) is still treated as a removal. When a hunk
/// header carries no parseable counts, or after the counted lines are used up,
/// any line starting with `---` or `+++` is treated as a file header.
///
/// A change made up solely of empty lines produces no output.
pub fn unified_to_word_diff(unified_diff: &str) -> String {
    let mut result = String::with_capacity(unified_diff.len());
    let mut old_lines: Vec<&str> = Vec::new();
    let mut new_lines: Vec<&str> = Vec::new();
    // Old/new lines still expected in the current hunk, if its header had counts.
    let mut remaining: Option<(usize, usize)> = None;

    for line in unified_diff.lines() {
        let in_hunk = matches!(remaining, Some((old, new)) if old > 0 || new > 0);

        if line.starts_with("@@") {
            flush_changes(&mut old_lines, &mut new_lines, &mut result);
            push_line(&mut result, line);
            remaining = parse_hunk_counts(line);
        } else if !in_hunk && (line.starts_with("---") || line.starts_with("+++")) {
            flush_changes(&mut old_lines, &mut new_lines, &mut result);
            push_line(&mut result, line);
        } else if line.starts_with('-') {
            old_lines.push(line);
            remaining = remaining.map(|(old, new)| (old.saturating_sub(1), new));
        } else if line.starts_with('+') {
            new_lines.push(line);
            remaining = remaining.map(|(old, new)| (old, new.saturating_sub(1)));
        } else {
            flush_changes(&mut old_lines, &mut new_lines, &mut result);
            push_line(&mut result, line);
            remaining = match line.as_bytes().first() {
                // Context line (an empty line is a context line whose leading
                // space was stripped): counts against both sides.
                None | Some(b' ') => {
                    remaining.map(|(old, new)| (old.saturating_sub(1), new.saturating_sub(1)))
                }
                // "\ No newline at end of file" does not count as a line.
                Some(b'\\') => remaining,
                // Anything else (diff --git, index, ...) cannot be hunk content.
                Some(_) => None,
            };
        }
    }

    flush_changes(&mut old_lines, &mut new_lines, &mut result);
    result
}

/// Append `line` and a newline to `out`.
fn push_line(out: &mut String, line: &str) {
    out.push_str(line);
    out.push('\n');
}

/// Emit the word diff for the pending removed/added lines and clear both lists.
fn flush_changes(old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, result: &mut String) {
    if old_lines.is_empty() && new_lines.is_empty() {
        return;
    }

    let old_text = join_without_marker(old_lines);
    let new_text = join_without_marker(new_lines);
    result.push_str(&compute_word_diff(&old_text, &new_text));

    old_lines.clear();
    new_lines.clear();
}

/// Strip the leading `-`/`+` from each line and join the lines with `\n`.
fn join_without_marker(lines: &[&str]) -> String {
    lines
        .iter()
        .map(|line| line.get(1..).unwrap_or(""))
        .collect::<Vec<_>>()
        .join("\n")
}

/// Read the old/new line counts from a hunk header such as `@@ -1,3 +1,4 @@`.
///
/// A range without a count (`-7`) means one line. Returns `None` when the
/// header does not have this shape.
fn parse_hunk_counts(header: &str) -> Option<(usize, usize)> {
    fn range_len(range: &str) -> Option<usize> {
        match range.split_once(',') {
            Some((_, count)) => count.parse().ok(),
            None => range.parse::<usize>().ok().map(|_| 1),
        }
    }

    let mut parts = header.strip_prefix("@@")?.split_whitespace();
    let old = range_len(parts.next()?.strip_prefix('-')?)?;
    let new = range_len(parts.next()?.strip_prefix('+')?)?;
    Some((old, new))
}

/// Compute word-level diff between two text blocks.
///
/// Words and whitespace are treated as separate tokens. The output uses:
/// - `[-...-]` for deleted content
/// - `{+...+}` for inserted content
///
/// A non-empty result always ends with a newline.
fn compute_word_diff(old_text: &str, new_text: &str) -> String {
    fn push_marked(out: &mut String, open: &str, tokens: &[&str], close: &str) {
        out.push_str(open);
        out.extend(tokens.iter().copied());
        out.push_str(close);
    }

    // Split into words while preserving whitespace
    let old_words = tokenize(old_text);
    let new_words = tokenize(new_text);

    let mut result = String::with_capacity(old_text.len() + new_text.len());

    for op in diff_tokens(&old_words, &new_words) {
        match op {
            DiffOp::Equal(start, end) => result.extend(old_words[start..end].iter().copied()),
            DiffOp::Delete(start, end) => {
                push_marked(&mut result, "[-", &old_words[start..end], "-]");
            }
            DiffOp::Insert(start, end) => {
                push_marked(&mut result, "{+", &new_words[start..end], "+}");
            }
            DiffOp::Replace {
                old_start,
                old_end,
                new_start,
                new_end,
            } => {
                push_marked(&mut result, "[-", &old_words[old_start..old_end], "-]");
                push_marked(&mut result, "{+", &new_words[new_start..new_end], "+}");
            }
        }
    }

    if !result.is_empty() && !result.ends_with('\n') {
        result.push('\n');
    }

    result
}

/// Tokenize text into words and whitespace sequences.
///
/// Each token is a maximal run of either whitespace or non-whitespace
/// characters; concatenating the tokens yields `text` again.
fn tokenize(text: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    let mut start = 0;
    let mut previous_is_whitespace: Option<bool> = None;

    for (idx, ch) in text.char_indices() {
        let is_whitespace = ch.is_whitespace();
        // A token ends wherever the character class flips.
        if matches!(previous_is_whitespace, Some(previous) if previous != is_whitespace) {
            tokens.push(&text[start..idx]);
            start = idx;
        }
        previous_is_whitespace = Some(is_whitespace);
    }

    if start < text.len() {
        tokens.push(&text[start..]);
    }

    tokens
}

/// One step of a token-level diff.
///
/// `Equal` and `Delete` hold a half-open index range into the old tokens,
/// `Insert` a range into the new tokens, and `Replace` one range of each.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DiffOp {
    Equal(usize, usize),
    Delete(usize, usize),
    Insert(usize, usize),
    Replace {
        old_start: usize,
        old_end: usize,
        new_start: usize,
        new_end: usize,
    },
}

/// Upper bound on the number of cells in the LCS table. Larger inputs are not
/// aligned; their differing middle part is reported as a single change.
const MAX_LCS_CELLS: usize = 1 << 22;

/// Compute diff operations between two token sequences using a simple LCS-based algorithm.
///
/// The common prefix and suffix are matched directly and the longest common
/// subsequence of the remaining middle decides which tokens are kept. Adjacent
/// kept tokens are merged into one `Equal`; the tokens between two kept runs
/// become a `Delete`, an `Insert`, or a `Replace` when both sides changed.
fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
    let (m, n) = (old.len(), new.len());

    let prefix = old.iter().zip(new).take_while(|(a, b)| a == b).count();
    let suffix = old[prefix..]
        .iter()
        .rev()
        .zip(new[prefix..].iter().rev())
        .take_while(|(a, b)| a == b)
        .count();

    let middle = lcs_matches(&old[prefix..m - suffix], &new[prefix..n - suffix]);

    // All (old index, new index) pairs of kept tokens, in increasing order.
    let matches = (0..prefix)
        .map(|k| (k, k))
        .chain(middle.into_iter().map(|(i, j)| (prefix + i, prefix + j)))
        .chain((0..suffix).map(|k| (m - suffix + k, n - suffix + k)));

    let mut ops = Vec::new();
    let mut old_pos = 0;
    let mut new_pos = 0;
    // The current run of kept tokens is `equal_start..old_pos` (old indices).
    let mut equal_start = 0;

    for (old_idx, new_idx) in matches {
        if old_idx != old_pos || new_idx != new_pos {
            if equal_start < old_pos {
                ops.push(DiffOp::Equal(equal_start, old_pos));
            }
            push_change(&mut ops, old_pos, old_idx, new_pos, new_idx);
            equal_start = old_idx;
        }
        old_pos = old_idx + 1;
        new_pos = new_idx + 1;
    }

    if equal_start < old_pos {
        ops.push(DiffOp::Equal(equal_start, old_pos));
    }
    push_change(&mut ops, old_pos, m, new_pos, n);

    ops
}

/// Push the op describing `old[old_start..old_end]` turning into
/// `new[new_start..new_end]`; pushes nothing when both ranges are empty.
fn push_change(
    ops: &mut Vec<DiffOp>,
    old_start: usize,
    old_end: usize,
    new_start: usize,
    new_end: usize,
) {
    match (old_start < old_end, new_start < new_end) {
        (true, true) => ops.push(DiffOp::Replace {
            old_start,
            old_end,
            new_start,
            new_end,
        }),
        (true, false) => ops.push(DiffOp::Delete(old_start, old_end)),
        (false, true) => ops.push(DiffOp::Insert(new_start, new_end)),
        (false, false) => {}
    }
}

/// Index pairs `(i, j)` with `old[i] == new[j]` forming a longest common
/// subsequence, in increasing order.
///
/// Returns no pairs when either side is empty or when the table would exceed
/// [`MAX_LCS_CELLS`].
fn lcs_matches(old: &[&str], new: &[&str]) -> Vec<(usize, usize)> {
    let (rows, cols) = (old.len(), new.len());
    if rows == 0 || cols == 0 || rows.saturating_mul(cols) > MAX_LCS_CELLS {
        return Vec::new();
    }

    // lcs[i * width + j] = LCS length of old[i..] and new[j..]. Filling the
    // table from the end lets the walk below run front to back.
    let width = cols + 1;
    let mut lcs = vec![0u32; (rows + 1) * width];
    for i in (0..rows).rev() {
        for j in (0..cols).rev() {
            lcs[i * width + j] = if old[i] == new[j] {
                lcs[(i + 1) * width + j + 1] + 1
            } else {
                lcs[(i + 1) * width + j].max(lcs[i * width + j + 1])
            };
        }
    }

    let mut matches = Vec::new();
    let (mut i, mut j) = (0, 0);
    while i < rows && j < cols {
        if old[i] == new[j] {
            matches.push((i, j));
            i += 1;
            j += 1;
        } else if lcs[(i + 1) * width + j] >= lcs[i * width + j + 1] {
            i += 1;
        } else {
            j += 1;
        }
    }

    matches
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("hello world");
        assert_eq!(tokens, vec!["hello", " ", "world"]);

        let tokens = tokenize("  multiple   spaces  ");
        assert_eq!(tokens, vec!["  ", "multiple", "   ", "spaces", "  "]);
    }

    #[test]
    fn test_compute_word_diff_simple() {
        let result = compute_word_diff("hello world", "hello there");
        assert!(result.contains("[-world-]"));
        assert!(result.contains("{+there+}"));
    }

    #[test]
    fn test_compute_word_diff_keeps_longest_common_run() {
        let result = compute_word_diff("x a b c", "a b c x");
        assert_eq!(result, "[-x -]a b c{+ x+}\n");
    }

    #[test]
    fn test_unified_to_word_diff() {
        let unified = "\
--- a/file.txt
+++ b/file.txt
@@ -1,3 +1,3 @@
 context line
-old text here
+new text here
 more context";

        let result = unified_to_word_diff(unified);
        assert!(result.contains("--- a/file.txt"));
        assert!(result.contains("+++ b/file.txt"));
        assert!(result.contains("@@"));
    }

    #[test]
    fn test_unified_to_word_diff_dash_prefixed_content() {
        let unified = "--- a/q.sql\n+++ b/q.sql\n@@ -1,2 +1,2 @@\n SELECT 1;\n--- old note\n+-- new note";
        let result = unified_to_word_diff(unified);
        assert!(result.ends_with("-- [-old-]{+new+} note\n"));
    }
}