@@ -12,12 +12,14 @@ mod paths;
mod predict;
mod progress;
mod pull_examples;
+mod qa;
mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
+mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
@@ -159,6 +161,8 @@ enum Command {
FilterLanguages(FilterLanguagesArgs),
/// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
ImportBatch(ImportBatchArgs),
+ /// Assess the quality of predictions using LLM-as-a-judge
+ Qa(qa::QaArgs),
}
impl Display for Command {
@@ -194,6 +198,9 @@ impl Display for Command {
Command::ImportBatch(args) => {
write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
}
+ Command::Qa(_) => {
+ write!(f, "qa")
+ }
}
}
}
@@ -558,6 +565,32 @@ fn main() {
}
return;
}
+            Command::Qa(qa_args) => {
+                // Read examples from input files, then narrow them with the
+                // same filter flags the other subcommands honor.
+                let mut examples = example::read_example_files(&args.inputs);
+
+                // Apply filters
+                if let Some(name_filter) = &args.name {
+                    examples.retain(|e| e.spec.name.contains(name_filter));
+                }
+                if let Some(repo_filter) = &args.repo {
+                    examples.retain(|e| e.spec.repository_url.contains(repo_filter));
+                }
+                if let Some(offset) = args.offset {
+                    // `splice(0..offset, [])` panics when offset > examples.len();
+                    // clamp so an over-large --offset yields an empty set instead.
+                    examples.drain(..offset.min(examples.len()));
+                }
+                if let Some(limit) = args.limit {
+                    examples.truncate(limit);
+                }
+
+                smol::block_on(async {
+                    if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
+                        eprintln!("Error: {:?}", e);
+                        std::process::exit(1);
+                    }
+                });
+                return;
+            }
_ => {}
}
@@ -724,7 +757,8 @@ fn main() {
| Command::SplitCommit(_)
| Command::Split(_)
| Command::FilterLanguages(_)
- | Command::ImportBatch(_) => {
+ | Command::ImportBatch(_)
+ | Command::Qa(_) => {
unreachable!()
}
}
@@ -0,0 +1,395 @@
+//! Quality assessment of predictions using LLM-as-a-judge.
+//!
+//! This module uses the Anthropic Batch API to evaluate prediction quality.
+//! Caching is handled by the underlying AnthropicClient.
+
+use crate::anthropic_client::AnthropicClient;
+use crate::example::Example;
+use crate::paths::CACHE_DIR;
+use crate::word_diff::unified_to_word_diff;
+use anthropic::{Message, RequestContent, Role};
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+use std::io::{BufWriter, Write};
+use std::path::PathBuf;
+use std::sync::LazyLock;
+
+/// Model to use for QA evaluation.
+const MODEL: &str = "claude-sonnet-4-5";
+
+/// Path to the QA cache database.
+// Only consulted in batch mode (`AnthropicClient::batch`); the plain
+// synchronous client is constructed without a cache path.
+pub static QA_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("qa_cache.sqlite"));
+
+/// Arguments for the QA command.
+#[derive(Debug, Clone, clap::Args)]
+pub struct QaArgs {
+    /// Use synchronous API instead of batch
+    // When set, requests go one at a time through `AnthropicClient::plain()`.
+    #[clap(long)]
+    pub no_batch: bool,
+
+    /// Wait for batch to complete (polls every 30s)
+    // Only read in batch mode; has no effect when --no-batch is set.
+    #[clap(long)]
+    pub wait: bool,
+}
+
+/// Result of QA evaluation for a single prediction.
+///
+/// Every field is optional (and omitted from serialization when `None`) so
+/// that a partially-parseable judge response can still be recorded.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct QaResult {
+    /// Free-form reasoning from the judge.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<String>,
+
+    /// Does the prediction undo/revert changes the user intentionally made?
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reverts_edits: Option<bool>,
+
+    /// Confidence score (1-5) for user acceptance likelihood.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub confidence: Option<u8>,
+
+    /// The raw response from the model.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub response: Option<String>,
+
+    /// Error message if parsing or request failed.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+}
+
+/// Build the assessment prompt for an example.
+///
+/// Returns `None` when the example lacks a first prediction, that
+/// prediction's `actual_patch`, or `prompt_inputs`; callers treat `None`
+/// as "skip this example".
+pub fn build_prompt(example: &Example) -> Option<String> {
+    let prediction = example.predictions.first()?;
+    // NOTE(review): the field is named `actual_patch` but is presented to the
+    // judge as the "Predicted Next Edit" below — confirm it holds the model's
+    // prediction rather than the ground-truth edit.
+    let actual_patch = prediction.actual_patch.as_ref()?;
+    let prompt_inputs = example.prompt_inputs.as_ref()?;
+
+    // Show the prediction at word granularity so the judge sees exactly
+    // which tokens changed.
+    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
+
+    // Render the edit history as word-diffed hunks, each preceded by
+    // `--- a...` / `+++ b...` file headers.
+    let mut edit_history = String::new();
+    for event in &prompt_inputs.edit_history {
+        match event.as_ref() {
+            zeta_prompt::Event::BufferChange {
+                path,
+                old_path,
+                diff,
+                predicted: _,
+                in_open_source_repo: _,
+            } => {
+                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
+                edit_history.push_str(&format!("+++ b{}\n", path.display()));
+                let diff_word_diff = unified_to_word_diff(diff);
+                edit_history.push_str(&diff_word_diff);
+                edit_history.push_str("\n\n");
+            }
+        }
+    }
+
+    Some(format!(
+        r#"
+You are evaluating an edit prediction model for a code editor. The model observes a programmer's recent edit history and predicts what edit they will make next.
+
+All diffs are in the word-diff format.
+
+The model is instructed to:
+- Complete partially-applied refactoring or changes
+- Maintain consistency with established patterns and style
+- NOT delete or revert text that was just added (unless the user explicitly undid it themselves)
+
+## Edit History (chronological)
+```````
+{edit_history}
+```````
+
+## Predicted Next Edit
+```````
+{actual_patch_word_diff}
+```````
+
+## Evaluate
+
+1. **reverts_edits**: Does the prediction undo, or revert changes the user intentionally made in the **edit history**?
+
+2. **confidence**: How likely is the user to accept this suggestion?
+   - 1 = Definitely reject (wrong, nonsensical, or harmful)
+   - 2 = Probably reject (doesn't fit intent or pattern)
+   - 3 = Uncertain (plausible but not clearly correct)
+   - 4 = Probably accept (reasonable next step)
+   - 5 = Definitely accept (obvious continuation)
+
+Output JSON in this format:
+
+```
+{{
+  "reasoning": "your reasoning here",
+  "reverts_edits": true/false,
+  "confidence": 1-5
+}}
+```
+"#
+    ))
+}
+
+/// Extract a code block from a response.
+///
+/// Returns the contents of the first ``` fence; if that fence is never
+/// closed, everything after it is returned. `None` when no fence exists.
+fn extract_codeblock(response: &str) -> Option<String> {
+    let all: Vec<&str> = response.lines().collect();
+    let open = all.iter().position(|line| line.starts_with("```"))?;
+    let body = &all[open + 1..];
+    let close = body
+        .iter()
+        .position(|line| line.starts_with("```"))
+        .unwrap_or(body.len());
+    Some(body[..close].join("\n"))
+}
+
+/// Parse the LLM response into a QaResult.
+///
+/// Tries the first fenced code block first, then the raw (trimmed) response.
+/// On the first successful JSON parse, fields are extracted leniently —
+/// missing or mistyped fields simply become `None`. If neither attempt
+/// yields valid JSON, a result with `error` set is returned and the raw
+/// text is preserved (in both `reasoning` and `response`) for inspection.
+fn parse_response(response_text: &str) -> QaResult {
+    let codeblock = extract_codeblock(response_text);
+
+    // Try parsing codeblock first, then fall back to raw response
+    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
+        let Some(text) = text_to_parse else {
+            continue;
+        };
+
+        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
+            return QaResult {
+                reasoning: parsed
+                    .get("reasoning")
+                    .and_then(|v| v.as_str())
+                    .map(|s| s.to_string()),
+                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
+                confidence: parsed
+                    .get("confidence")
+                    .and_then(|v| v.as_u64())
+                    .map(|v| v as u8),
+                // Keep the full raw response alongside the parsed fields
+                // for debugging/auditing.
+                response: Some(response_text.to_string()),
+                error: None,
+            };
+        }
+    }
+
+    // If all parsing attempts fail, return error
+    QaResult {
+        reasoning: Some(response_text.to_string()),
+        reverts_edits: None,
+        confidence: None,
+        response: Some(response_text.to_string()),
+        error: Some("Could not parse JSON from response".to_string()),
+    }
+}
+
+/// Run the QA evaluation on a set of examples.
+///
+/// Builds one judge prompt per example (skipping those `build_prompt`
+/// rejects), sends the prompts either synchronously (`--no-batch`) or via
+/// the batching client, then writes each evaluated example as one JSON line
+/// with the judge verdict attached under a `"qa"` key — to `output_path` if
+/// given, otherwise stdout. Finishes by printing summary stats (item count
+/// and reverts-edits rate) to stderr.
+pub async fn run_qa(
+    examples: &mut [Example],
+    args: &QaArgs,
+    output_path: Option<&PathBuf>,
+) -> Result<()> {
+    let client = if args.no_batch {
+        AnthropicClient::plain()?
+    } else {
+        AnthropicClient::batch(&QA_CACHE_DB)?
+    };
+
+    eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);
+
+    // First pass: send requests (client handles caching internally)
+    let mut prompts: Vec<(usize, String)> = Vec::new();
+    let mut skipped_count = 0;
+
+    for (idx, example) in examples.iter().enumerate() {
+        let Some(prompt) = build_prompt(example) else {
+            skipped_count += 1;
+            continue;
+        };
+        prompts.push((idx, prompt));
+    }
+
+    if skipped_count > 0 {
+        eprintln!("Skipping {} items with missing actual_patch", skipped_count);
+    }
+
+    eprintln!("{} items to process", prompts.len());
+
+    // Process all items. A `None` result means "no response yet" (batch
+    // still pending); those are retried or reported below.
+    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();
+
+    if args.no_batch {
+        // Synchronous processing
+        for (i, (idx, prompt)) in prompts.iter().enumerate() {
+            eprint!("\rProcessing {}/{}", i + 1, prompts.len());
+
+            let messages = vec![Message {
+                role: Role::User,
+                content: vec![RequestContent::Text {
+                    text: prompt.clone(),
+                    cache_control: None,
+                }],
+            }];
+
+            let response = client.generate(MODEL, 1024, messages).await?;
+            // Concatenate all text content blocks into one string for parsing.
+            let result = response.map(|r| {
+                let text = r
+                    .content
+                    .iter()
+                    .filter_map(|c| match c {
+                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>()
+                    .join("");
+                parse_response(&text)
+            });
+            results.push((*idx, result));
+        }
+        eprintln!();
+    } else {
+        // Queue all for batching. NOTE(review): this assumes `generate`
+        // returns Ok(None) while a batched request is pending and enqueues
+        // it as a side effect — verify against AnthropicClient.
+        for (idx, prompt) in &prompts {
+            let messages = vec![Message {
+                role: Role::User,
+                content: vec![RequestContent::Text {
+                    text: prompt.clone(),
+                    cache_control: None,
+                }],
+            }];
+
+            let response = client.generate(MODEL, 1024, messages).await?;
+            let result = response.map(|r| {
+                let text = r
+                    .content
+                    .iter()
+                    .filter_map(|c| match c {
+                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>()
+                    .join("");
+                parse_response(&text)
+            });
+            results.push((*idx, result));
+        }
+
+        // Sync batches (upload pending, download finished)
+        client.sync_batches().await?;
+
+        if args.wait {
+            eprintln!("Waiting for batch to complete...");
+            loop {
+                // NOTE(review): blocking sleep inside an async block. The
+                // future is driven by `smol::block_on` with no sibling tasks,
+                // so this works in practice, but `smol::Timer::after` would
+                // be the non-blocking form — confirm intent.
+                std::thread::sleep(std::time::Duration::from_secs(30));
+                client.sync_batches().await?;
+
+                // Re-check all items that didn't have results
+                let mut all_done = true;
+                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
+                    if results[result_idx].1.is_none() {
+                        let messages = vec![Message {
+                            role: Role::User,
+                            content: vec![RequestContent::Text {
+                                text: prompt.clone(),
+                                cache_control: None,
+                            }],
+                        }];
+
+                        let response = client.generate(MODEL, 1024, messages).await?;
+                        if let Some(r) = response {
+                            let text = r
+                                .content
+                                .iter()
+                                .filter_map(|c| match c {
+                                    anthropic::ResponseContent::Text { text } => {
+                                        Some(text.as_str())
+                                    }
+                                    _ => None,
+                                })
+                                .collect::<Vec<_>>()
+                                .join("");
+                            results[result_idx] = (*idx, Some(parse_response(&text)));
+                        } else {
+                            all_done = false;
+                        }
+                    }
+                }
+
+                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
+                if all_done {
+                    break;
+                }
+                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
+            }
+        } else {
+            // Fire-and-forget mode: report how many items remain pending so
+            // the user knows to re-run later.
+            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
+            if pending_count > 0 {
+                eprintln!(
+                    "Batch submitted. {} pending. Run again later to retrieve results.",
+                    pending_count
+                );
+            }
+        }
+    }
+
+    // Build results map by index
+    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
+        std::collections::HashMap::new();
+    for (idx, result) in results {
+        if let Some(r) = result {
+            results_by_idx.insert(idx, r);
+        }
+    }
+
+    // Output results
+    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
+        Box::new(BufWriter::new(std::fs::File::create(path)?))
+    } else {
+        Box::new(std::io::stdout())
+    };
+
+    let mut num_total = 0;
+    let mut num_reverts_edits = 0;
+
+    // NOTE(review): `iter_mut` appears unnecessary — examples are only read
+    // (serialized) in this loop.
+    for (idx, example) in examples.iter_mut().enumerate() {
+        // Skip examples that couldn't be processed
+        if build_prompt(example).is_none() {
+            continue;
+        }
+
+        // Items that were sent but have no response yet get a placeholder
+        // result with `error` set, so every emitted line has a "qa" field.
+        let result = results_by_idx
+            .get(&idx)
+            .cloned()
+            .unwrap_or_else(|| QaResult {
+                reasoning: None,
+                reverts_edits: None,
+                confidence: None,
+                response: None,
+                error: Some("Result not found".to_string()),
+            });
+
+        if result.reverts_edits == Some(true) {
+            num_reverts_edits += 1;
+        }
+        num_total += 1;
+
+        // Add QA result to example and output
+        let mut example_json = serde_json::to_value(&example)?;
+        example_json["qa"] = serde_json::to_value(&result)?;
+        writeln!(writer, "{}", serde_json::to_string(&example_json)?)?;
+    }
+
+    if let Some(path) = output_path {
+        eprintln!("Results written to {}", path.display());
+    }
+
+    eprintln!("Processed: {} items", num_total);
+    if num_total > 0 {
+        eprintln!(
+            "Reverts edits: {} ({:.2}%)",
+            num_reverts_edits,
+            num_reverts_edits as f64 / num_total as f64 * 100.0
+        );
+    }
+
+    Ok(())
+}
@@ -0,0 +1,343 @@
+//! Word-diff utilities for converting unified diffs to word-diff format.
+
+/// Convert unified diff to word-diff format.
+///
+/// This transforms line-based diffs into word-level diffs where:
+/// - Deletions are marked with `[-...-]`
+/// - Insertions are marked with `{+...+}`
+///
+/// Header lines (`---`/`+++`, `@@`, `diff --git`, etc.) and context lines
+/// pass through unchanged; each contiguous run of `-`/`+` lines is collapsed
+/// into a single word-level diff hunk.
+pub fn unified_to_word_diff(unified_diff: &str) -> String {
+    let lines: Vec<&str> = unified_diff.lines().collect();
+    let mut result = String::new();
+    // Accumulators for the current run of removed/added lines.
+    let mut old_lines: Vec<&str> = Vec::new();
+    let mut new_lines: Vec<&str> = Vec::new();
+
+    // Emit the pending -/+ run (if any) as one word-diff hunk, then reset.
+    let flush_changes =
+        |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, result: &mut String| {
+            if old_lines.is_empty() && new_lines.is_empty() {
+                return;
+            }
+
+            // Strip the leading '-' or '+' from each line
+            let old_text: String = old_lines
+                .iter()
+                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            let new_text: String = new_lines
+                .iter()
+                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            if !old_text.is_empty() || !new_text.is_empty() {
+                let word_diff = compute_word_diff(&old_text, &new_text);
+                result.push_str(&word_diff);
+            }
+
+            old_lines.clear();
+            new_lines.clear();
+        };
+
+    for line in lines {
+        // The "---"/"+++" checks must come before the plain '-'/'+' branches
+        // so file headers are not mistaken for removed/added lines.
+        // NOTE(review): a removed line whose own text starts with "--"
+        // (raw line "---…") is also caught here — confirm that is acceptable
+        // for the diffs fed in.
+        if line.starts_with("---") || line.starts_with("+++") {
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        } else if line.starts_with("@@") {
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        } else if line.starts_with('-') {
+            old_lines.push(line);
+        } else if line.starts_with('+') {
+            new_lines.push(line);
+        } else if line.starts_with(' ') || line.is_empty() {
+            // Context lines end the current change run.
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        } else {
+            // Header lines (diff --git, index, etc.)
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        }
+    }
+
+    // Emit any change run that reaches end-of-input.
+    flush_changes(&mut old_lines, &mut new_lines, &mut result);
+    result
+}
+
+/// Compute word-level diff between two text blocks.
+///
+/// Words and whitespace runs are separate tokens. The output uses:
+/// - `[-...-]` for deleted content
+/// - `{+...+}` for inserted content
+fn compute_word_diff(old_text: &str, new_text: &str) -> String {
+    let old_tokens = tokenize(old_text);
+    let new_tokens = tokenize(new_text);
+    let mut out = String::new();
+
+    // Append a range of tokens verbatim to the output buffer.
+    let append = |buf: &mut String, tokens: &[&str]| {
+        for token in tokens {
+            buf.push_str(token);
+        }
+    };
+
+    for op in diff_tokens(&old_tokens, &new_tokens) {
+        match op {
+            DiffOp::Equal(start, end) => append(&mut out, &old_tokens[start..end]),
+            DiffOp::Delete(start, end) => {
+                out.push_str("[-");
+                append(&mut out, &old_tokens[start..end]);
+                out.push_str("-]");
+            }
+            DiffOp::Insert(start, end) => {
+                out.push_str("{+");
+                append(&mut out, &new_tokens[start..end]);
+                out.push_str("+}");
+            }
+            DiffOp::Replace {
+                old_start,
+                old_end,
+                new_start,
+                new_end,
+            } => {
+                out.push_str("[-");
+                append(&mut out, &old_tokens[old_start..old_end]);
+                out.push_str("-]{+");
+                append(&mut out, &new_tokens[new_start..new_end]);
+                out.push_str("+}");
+            }
+        }
+    }
+
+    // Ensure the hunk ends with a newline.
+    if !out.is_empty() && !out.ends_with('\n') {
+        out.push('\n');
+    }
+
+    out
+}
+
+/// Tokenize text into words and whitespace sequences.
+///
+/// Adjacent characters of the same class (whitespace vs. non-whitespace)
+/// are grouped into a single token; tokens concatenate back to the input.
+fn tokenize(text: &str) -> Vec<&str> {
+    let mut tokens = Vec::new();
+    // (byte offset where the current run began, whether it is whitespace)
+    let mut run: Option<(usize, bool)> = None;
+
+    for (offset, ch) in text.char_indices() {
+        let is_ws = ch.is_whitespace();
+        match run {
+            // Same class as the current run: keep extending it.
+            Some((_, run_ws)) if run_ws == is_ws => {}
+            // Class changed: close the previous run, start a new one.
+            Some((start, _)) => {
+                tokens.push(&text[start..offset]);
+                run = Some((offset, is_ws));
+            }
+            // First character of the input.
+            None => run = Some((offset, is_ws)),
+        }
+    }
+
+    // Close the final run, if any.
+    if let Some((start, _)) = run {
+        tokens.push(&text[start..]);
+    }
+
+    tokens
+}
+
+/// A single edit operation over the token sequences, expressed as half-open
+/// `start..end` index ranges into the tokenized inputs.
+#[derive(Debug)]
+enum DiffOp {
+    // Tokens old[start..end] are unchanged (indices into the OLD sequence).
+    Equal(usize, usize),
+    // Tokens old[start..end] were removed.
+    Delete(usize, usize),
+    // Tokens new[start..end] were added (indices into the NEW sequence).
+    Insert(usize, usize),
+    // Tokens old[old_start..old_end] were replaced by new[new_start..new_end].
+    Replace {
+        old_start: usize,
+        old_end: usize,
+        new_start: usize,
+        new_end: usize,
+    },
+}
+
+/// Compute diff operations between two token sequences.
+///
+/// Uses a greedy first-common-token scan: repeatedly find the earliest pair
+/// of matching tokens, emit the non-matching prefix as Delete/Insert/Replace,
+/// then extend the equal run as far as it goes. This is not a minimal (LCS)
+/// diff, but it is deterministic and adequate for word-level display.
+///
+/// The previous version also built an O(m*n) LCS dp table and a backtrack
+/// stack that were never read — the emitted ops came solely from this greedy
+/// scan — so that dead work/memory has been removed.
+fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
+    let m = old.len();
+    let n = new.len();
+
+    // Degenerate cases: one or both sides empty.
+    if m == 0 && n == 0 {
+        return vec![];
+    }
+    if m == 0 {
+        return vec![DiffOp::Insert(0, n)];
+    }
+    if n == 0 {
+        return vec![DiffOp::Delete(0, m)];
+    }
+
+    let mut ops = Vec::new();
+    let mut old_idx = 0;
+    let mut new_idx = 0;
+
+    while old_idx < m || new_idx < n {
+        // Find the earliest matching token pair at or after (old_idx, new_idx).
+        let mut old_match = None;
+        let mut new_match = None;
+
+        for oi in old_idx..m {
+            for ni in new_idx..n {
+                if old[oi] == new[ni] {
+                    old_match = Some(oi);
+                    new_match = Some(ni);
+                    break;
+                }
+            }
+            if old_match.is_some() {
+                break;
+            }
+        }
+
+        match (old_match, new_match) {
+            (Some(om), Some(nm)) => {
+                // Emit whatever precedes the match on either side.
+                if old_idx < om && new_idx < nm {
+                    ops.push(DiffOp::Replace {
+                        old_start: old_idx,
+                        old_end: om,
+                        new_start: new_idx,
+                        new_end: nm,
+                    });
+                } else if old_idx < om {
+                    ops.push(DiffOp::Delete(old_idx, om));
+                } else if new_idx < nm {
+                    ops.push(DiffOp::Insert(new_idx, nm));
+                }
+
+                // Extend the run of equal tokens as far as it goes.
+                let mut eq_end_old = om;
+                let mut eq_end_new = nm;
+                while eq_end_old < m && eq_end_new < n && old[eq_end_old] == new[eq_end_new] {
+                    eq_end_old += 1;
+                    eq_end_new += 1;
+                }
+
+                ops.push(DiffOp::Equal(om, eq_end_old));
+                old_idx = eq_end_old;
+                new_idx = eq_end_new;
+            }
+            _ => {
+                // No common token remains: emit the tails and stop.
+                if old_idx < m && new_idx < n {
+                    ops.push(DiffOp::Replace {
+                        old_start: old_idx,
+                        old_end: m,
+                        new_start: new_idx,
+                        new_end: n,
+                    });
+                } else if old_idx < m {
+                    ops.push(DiffOp::Delete(old_idx, m));
+                } else if new_idx < n {
+                    ops.push(DiffOp::Insert(new_idx, n));
+                }
+                break;
+            }
+        }
+    }
+
+    ops
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tokenize() {
+        let tokens = tokenize("hello world");
+        assert_eq!(tokens, vec!["hello", " ", "world"]);
+
+        let tokens = tokenize(" multiple spaces ");
+        assert_eq!(tokens, vec![" ", "multiple", " ", "spaces", " "]);
+
+        // Edge cases: empty input and a single token.
+        assert!(tokenize("").is_empty());
+        assert_eq!(tokenize("one"), vec!["one"]);
+    }
+
+    #[test]
+    fn test_compute_word_diff_simple() {
+        let result = compute_word_diff("hello world", "hello there");
+        assert!(result.contains("[-world-]"));
+        assert!(result.contains("{+there+}"));
+        // Unchanged tokens pass through unmarked.
+        assert!(result.starts_with("hello "));
+
+        // Identical inputs produce no markers, just the text plus newline.
+        assert_eq!(compute_word_diff("same", "same"), "same\n");
+    }
+
+    #[test]
+    fn test_unified_to_word_diff() {
+        let unified = "\
+--- a/file.txt
++++ b/file.txt
+@@ -1,3 +1,3 @@
+ context line
+-old text here
++new text here
+ more context";
+
+        let result = unified_to_word_diff(unified);
+        // Headers pass through untouched.
+        assert!(result.contains("--- a/file.txt"));
+        assert!(result.contains("+++ b/file.txt"));
+        assert!(result.contains("@@"));
+        // The changed line is word-diffed: only the differing word is marked.
+        assert!(result.contains("[-old-]"));
+        assert!(result.contains("{+new+}"));
+        assert!(result.contains(" text here"));
+    }
+}