qa.rs

//! Quality assessment of predictions using LLM-as-a-judge.
//!
//! This module uses the Anthropic Batch API to evaluate prediction quality.
//! Caching is handled by the underlying AnthropicClient.
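//!
//! Typical batched flow: the first run enqueues requests and exits with them
//! pending; a later run (or passing `--wait`) syncs the batch and collects
//! the finished results.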

use crate::anthropic_client::AnthropicClient;
use crate::example::Example;
use crate::format_prompt::extract_cursor_excerpt_from_example;
use crate::paths::LLM_CACHE_DB;
use crate::word_diff::unified_to_word_diff;
use anthropic::{Message, RequestContent, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::io::{BufWriter, Write};
use std::path::PathBuf;

/// Model to use for QA evaluation.
const MODEL: &str = "claude-sonnet-4-5";

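/// Judge prompt template; `build_prompt` fills in the `{edit_history}`,
/// `{cursor_excerpt}`, and `{actual_patch_word_diff}` placeholders.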
const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");

/// Arguments for the QA command.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use the synchronous API instead of the Batch API
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,
}

/// Result of QA evaluation for a single prediction.
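///
/// The judge is asked to reply with a JSON object carrying `reasoning`,
/// `reverts_edits`, and `confidence` (see `parse_response`); `response` and
/// `error` are filled in locally. An illustrative reply:
///
/// ```json
/// { "reasoning": "Re-deletes code the user just restored.", "reverts_edits": true, "confidence": 2 }
/// ```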
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

/// Build the assessment prompt for an example.
pub fn build_prompt(example: &Example) -> Option<String> {
    let prediction = example.predictions.first()?;
    let actual_patch = prediction.actual_patch.as_ref()?;
    let prompt_inputs = example.prompt_inputs.as_ref()?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    // Format cursor excerpt (reuse from format_prompt)
    let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;

    let mut edit_history = String::new();
    for event in &prompt_inputs.edit_history {
        match event.as_ref() {
            zeta_prompt::Event::BufferChange {
                path,
                old_path,
                diff,
                predicted: _,
                in_open_source_repo: _,
            } => {
                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
                edit_history.push_str(&format!("+++ b{}\n", path.display()));
                let diff_word_diff = unified_to_word_diff(diff);
                edit_history.push_str(&diff_word_diff);
                edit_history.push_str("\n\n");
            }
        }
    }

    Some(
        PROMPT_TEMPLATE
            .replace("{edit_history}", &edit_history)
            .replace("{cursor_excerpt}", &cursor_excerpt)
            .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
    )
}

/// Extract a code block from a response.
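///
/// Returns the contents of the first fenced block; if the closing fence is
/// missing, everything after the opening fence is returned.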
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();
    for (i, line) in lines.iter().enumerate() {
        if line.starts_with("```") {
            let start = i + 1;
            for (j, end_line) in lines[start..].iter().enumerate() {
                if end_line.starts_with("```") {
                    return Some(lines[start..start + j].join("\n"));
                }
            }
            return Some(lines[start..].join("\n"));
        }
    }
    None
}

/// Parse the LLM response into a QaResult.
fn parse_response(response_text: &str) -> QaResult {
    let codeblock = extract_codeblock(response_text);

    // Try parsing codeblock first, then fall back to raw response
    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
        let Some(text) = text_to_parse else {
            continue;
        };

        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
            return QaResult {
                reasoning: parsed
                    .get("reasoning")
                    .and_then(|v| v.as_str())
                    .map(|s| s.to_string()),
                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
                confidence: parsed
                    .get("confidence")
                    .and_then(|v| v.as_u64())
                    .map(|v| v as u8),
                response: Some(response_text.to_string()),
                error: None,
            };
        }
    }

    // If all parsing attempts fail, return error
    QaResult {
        reasoning: Some(response_text.to_string()),
        reverts_edits: None,
        confidence: None,
        response: Some(response_text.to_string()),
        error: Some("Could not parse JSON from response".to_string()),
    }
}
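
/// Wrap a prompt in a single-turn user message. Shared by the sync, batch,
/// and polling paths below.
fn user_message(prompt: &str) -> Vec<Message> {
    vec![Message {
        role: Role::User,
        content: vec![RequestContent::Text {
            text: prompt.to_string(),
            cache_control: None,
        }],
    }]
}

/// Concatenate the text blocks of a response into one string, ignoring any
/// non-text content.
fn content_to_text(content: &[anthropic::ResponseContent]) -> String {
    content
        .iter()
        .filter_map(|c| match c {
            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join("")
}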

/// Run the QA evaluation on a set of examples.
pub async fn run_qa(
    examples: &mut [Example],
    args: &QaArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let client = if args.no_batch {
        AnthropicClient::plain()?
    } else {
        AnthropicClient::batch(&LLM_CACHE_DB)?
    };

    eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);

    // First pass: build prompts for every example with complete prediction
    // data; requests are sent below (the client handles caching internally).
    let mut prompts: Vec<(usize, String)> = Vec::new();
    let mut skipped_count = 0;

    for (idx, example) in examples.iter().enumerate() {
        let Some(prompt) = build_prompt(example) else {
            skipped_count += 1;
            continue;
        };
        prompts.push((idx, prompt));
    }

    if skipped_count > 0 {
        eprintln!("Skipping {} items with incomplete prediction data", skipped_count);
    }

    eprintln!("{} items to process", prompts.len());

    // Process all items
    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in prompts.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, prompts.len());
            let messages = user_message(prompt);
            let response = client.generate(MODEL, 1024, messages).await?;
            let result = response.map(|r| parse_response(&content_to_text(&r.content)));
            results.push((*idx, result));
        }
        eprintln!();
    } else {
        // Queue all for batching
        for (idx, prompt) in &prompts {
            let messages = user_message(prompt);
            let response = client.generate(MODEL, 1024, messages).await?;
            let result = response.map(|r| parse_response(&content_to_text(&r.content)));
            results.push((*idx, result));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
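                // Note: blocks the thread rather than yielding to the async
                // runtime; acceptable here since nothing else runs concurrently.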
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        let messages = user_message(prompt);
                        let response = client.generate(MODEL, 1024, messages).await?;
                        if let Some(r) = response {
                            results[result_idx] =
                                (*idx, Some(parse_response(&content_to_text(&r.content))));
                        } else {
                            all_done = false;
                        }
                    }
                }

                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                if all_done {
                    break;
                }
                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
            }
        } else {
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_total = 0;
    let mut num_reverts_edits = 0;

    for (idx, example) in examples.iter_mut().enumerate() {
        // Skip examples that couldn't be processed
        if build_prompt(example).is_none() {
            continue;
        }

        let result = results_by_idx.get(&idx).cloned();

        if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
            num_reverts_edits += 1;
        }
        num_total += 1;

        // Populate QA results for each prediction (currently only first prediction is evaluated)
        example.qa = example
            .predictions
            .iter()
            .enumerate()
            .map(|(i, _)| if i == 0 { result.clone() } else { None })
            .collect();

        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
    }

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Processed:     {} items", num_total);
    if num_total > 0 {
        eprintln!(
            "Reverts edits: {} ({:.2}%)",
            num_reverts_edits,
            num_reverts_edits as f64 / num_total as f64 * 100.0
        );
    }

    Ok(())
}
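
#[cfg(test)]
mod tests {
    // Sketch-level tests pinning down the parse_response fallback order
    // (fenced codeblock first, then raw JSON); payloads are illustrative.
    use super::*;

    #[test]
    fn parses_json_inside_codeblock() {
        let response =
            "preamble\n```json\n{\"reasoning\": \"ok\", \"reverts_edits\": false, \"confidence\": 4}\n```\n";
        let result = parse_response(response);
        assert_eq!(result.reverts_edits, Some(false));
        assert_eq!(result.confidence, Some(4));
        assert!(result.error.is_none());
    }

    #[test]
    fn falls_back_to_raw_json() {
        let result = parse_response("{\"confidence\": 2}");
        assert_eq!(result.confidence, Some(2));
    }

    #[test]
    fn reports_unparseable_responses() {
        let result = parse_response("not json");
        assert!(result.error.is_some());
        assert_eq!(result.confidence, None);
    }
}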