//! Quality assessment of predictions using LLM-as-a-judge.
//!
//! This module uses the Anthropic API (the Batch API by default, or the
//! synchronous API with `--no-batch`) to evaluate prediction quality.
//! Caching is handled by the underlying AnthropicClient.
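//!
//! A typical call looks roughly like this (a sketch only: loading `examples`,
//! building `QaArgs`, and choosing `output_path` happen elsewhere in the CLI):
//!
//! ```ignore
//! let args = QaArgs { no_batch: false, wait: true };
//! run_qa(&mut examples, &args, Some(&output_path)).await?;
//! ```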

use crate::anthropic_client::AnthropicClient;
use crate::example::Example;
use crate::format_prompt::extract_cursor_excerpt_from_example;
use crate::paths::LLM_CACHE_DB;
use crate::word_diff::unified_to_word_diff;
use anthropic::{Message, RequestContent, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::io::{BufWriter, Write};
use std::path::PathBuf;

/// Model to use for QA evaluation.
const MODEL: &str = "claude-sonnet-4-5";

/// Judge prompt template; `build_prompt` fills in the `{edit_history}`,
/// `{cursor_excerpt}`, and `{actual_patch_word_diff}` placeholders.
const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");

/// Arguments for the QA command.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,
}

/// Result of QA evaluation for a single prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

/// Build the assessment prompt for an example.
///
/// Returns `None` if the example lacks any of the inputs the prompt needs:
/// a prediction with its actual patch, the prompt inputs, or a cursor excerpt.
pub fn build_prompt(example: &Example) -> Option<String> {
    let prediction = example.predictions.first()?;
    let actual_patch = prediction.actual_patch.as_ref()?;
    let prompt_inputs = example.prompt_inputs.as_ref()?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    // Format cursor excerpt (reuse from format_prompt)
    let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;

    let mut edit_history = String::new();
    for event in &prompt_inputs.edit_history {
        match event.as_ref() {
            zeta_prompt::Event::BufferChange {
                path,
                old_path,
                diff,
                predicted: _,
                in_open_source_repo: _,
            } => {
                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
                edit_history.push_str(&format!("+++ b{}\n", path.display()));
                let diff_word_diff = unified_to_word_diff(diff);
                edit_history.push_str(&diff_word_diff);
                edit_history.push_str("\n\n");
            }
        }
    }

    Some(
        PROMPT_TEMPLATE
            .replace("{edit_history}", &edit_history)
            .replace("{cursor_excerpt}", &cursor_excerpt)
            .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
    )
}

/// Extract the contents of the first fenced code block from a response.
/// If the opening fence is never closed, everything after it is returned.
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();
    for (i, line) in lines.iter().enumerate() {
        if line.starts_with("```") {
            let start = i + 1;
            for (j, end_line) in lines[start..].iter().enumerate() {
                if end_line.starts_with("```") {
                    return Some(lines[start..start + j].join("\n"));
                }
            }
            return Some(lines[start..].join("\n"));
        }
    }
    None
}

/// Parse the LLM response into a QaResult.
fn parse_response(response_text: &str) -> QaResult {
    let codeblock = extract_codeblock(response_text);

    // Try parsing codeblock first, then fall back to raw response
    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
        let Some(text) = text_to_parse else {
            continue;
        };

        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
            return QaResult {
                reasoning: parsed
                    .get("reasoning")
                    .and_then(|v| v.as_str())
                    .map(|s| s.to_string()),
                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
                confidence: parsed
                    .get("confidence")
                    .and_then(|v| v.as_u64())
                    .map(|v| v as u8),
                response: Some(response_text.to_string()),
                error: None,
            };
        }
    }

    // If all parsing attempts fail, return error
    QaResult {
        reasoning: Some(response_text.to_string()),
        reverts_edits: None,
        confidence: None,
        response: Some(response_text.to_string()),
        error: Some("Could not parse JSON from response".to_string()),
    }
}

/// Run the QA evaluation on a set of examples.
pub async fn run_qa(
    examples: &mut [Example],
    args: &QaArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let client = if args.no_batch {
        AnthropicClient::plain()?
    } else {
        AnthropicClient::batch(&LLM_CACHE_DB)?
    };

    eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);

    // First pass: build a prompt for every example that has the required inputs
    let mut prompts: Vec<(usize, String)> = Vec::new();
    let mut skipped_count = 0;

    for (idx, example) in examples.iter().enumerate() {
        let Some(prompt) = build_prompt(example) else {
            skipped_count += 1;
            continue;
        };
        prompts.push((idx, prompt));
    }

    if skipped_count > 0 {
        eprintln!("Skipping {} items missing the data needed to build a prompt", skipped_count);
    }

    eprintln!("{} items to process", prompts.len());

    // Process all items (the client handles caching internally)
    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in prompts.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, prompts.len());

            let messages = vec![Message {
                role: Role::User,
                content: vec![RequestContent::Text {
                    text: prompt.clone(),
                    cache_control: None,
                }],
            }];

            let response = client.generate(MODEL, 1024, messages).await?;
            let result = response.map(|r| {
                let text = r
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("");
                parse_response(&text)
            });
            results.push((*idx, result));
        }
        eprintln!();
    } else {
        // Queue all for batching
        for (idx, prompt) in &prompts {
            let messages = vec![Message {
                role: Role::User,
                content: vec![RequestContent::Text {
                    text: prompt.clone(),
                    cache_control: None,
                }],
            }];

            let response = client.generate(MODEL, 1024, messages).await?;
            let result = response.map(|r| {
                let text = r
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("");
                parse_response(&text)
            });
            results.push((*idx, result));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        let messages = vec![Message {
                            role: Role::User,
                            content: vec![RequestContent::Text {
                                text: prompt.clone(),
                                cache_control: None,
                            }],
                        }];

                        let response = client.generate(MODEL, 1024, messages).await?;
                        if let Some(r) = response {
                            let text = r
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    anthropic::ResponseContent::Text { text } => {
                                        Some(text.as_str())
                                    }
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join("");
                            results[result_idx] = (*idx, Some(parse_response(&text)));
                        } else {
                            all_done = false;
                        }
                    }
                }

                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                if all_done {
                    break;
                }
                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
            }
        } else {
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_total = 0;
    let mut num_reverts_edits = 0;

    for (idx, example) in examples.iter_mut().enumerate() {
        // Skip examples that couldn't be processed
        if build_prompt(example).is_none() {
            continue;
        }

        let result = results_by_idx.get(&idx).cloned();

        if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
            num_reverts_edits += 1;
        }
        num_total += 1;

        // Populate QA results for each prediction (currently only first prediction is evaluated)
        example.qa = example
            .predictions
            .iter()
            .enumerate()
            .map(|(i, _)| if i == 0 { result.clone() } else { None })
            .collect();

        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
    }

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Processed: {} items", num_total);
    if num_total > 0 {
        eprintln!(
            "Reverts edits: {} ({:.2}%)",
            num_reverts_edits,
            num_reverts_edits as f64 / num_total as f64 * 100.0
        );
    }

    Ok(())
}
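
// Illustrative tests, added as a sketch: they exercise only the pure helpers
// `extract_codeblock` and `parse_response` above, using made-up judge responses.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_first_fenced_block() {
        let response = "Preamble\n```json\n{\"confidence\": 4}\n```\ntrailing text";
        assert_eq!(
            extract_codeblock(response).as_deref(),
            Some("{\"confidence\": 4}")
        );
    }

    #[test]
    fn parses_verdict_from_fenced_json() {
        let response = "Reasoning first.\n```json\n{\n  \"reasoning\": \"Consistent with the edits.\",\n  \"reverts_edits\": false,\n  \"confidence\": 4\n}\n```";
        let result = parse_response(response);
        assert_eq!(result.reverts_edits, Some(false));
        assert_eq!(result.confidence, Some(4));
        assert!(result.error.is_none());
    }

    #[test]
    fn reports_error_when_no_json_found() {
        let result = parse_response("I cannot evaluate this prediction.");
        assert!(result.error.is_some());
        assert_eq!(result.confidence, None);
    }
}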