1//! Quality assessment of predictions using LLM-as-a-judge.
2//!
3//! This module uses LLM Batch APIs to evaluate prediction quality.
4//! Caching is handled by the underlying client.
5
6use crate::BatchProvider;
7use crate::anthropic_client::AnthropicClient;
8use crate::example::Example;
9use crate::format_prompt::extract_cursor_excerpt_from_example;
10use crate::openai_client::OpenAiClient;
11use crate::paths::LLM_CACHE_DB;
12use crate::word_diff::unified_to_word_diff;
13use anyhow::Result;
14use serde::{Deserialize, Serialize};
15use std::io::{BufWriter, Write};
16use std::path::PathBuf;
17
/// Arguments for the QA command.
///
/// NOTE: the `///` doc comments on the fields below double as the CLI help
/// text rendered by clap — keep them user-facing.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "openai")]
    pub backend: BatchProvider,
}
33
34fn model_for_backend(backend: BatchProvider) -> &'static str {
35 match backend {
36 BatchProvider::Anthropic => "claude-sonnet-4-5",
37 BatchProvider::Openai => "gpt-5.2",
38 }
39}
40
/// Result of QA evaluation for a single prediction.
///
/// Every field is optional: the judge replies with free-form JSON, so any
/// key may be absent, and a reply that fails to parse populates only
/// `reasoning`/`response`/`error` (see `parse_response`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
64
65/// Build the assessment prompt for an example.
66pub fn build_prompt(example: &Example) -> Option<String> {
67 let prediction = example.predictions.first()?;
68 let actual_patch = prediction.actual_patch.as_ref()?;
69 let prompt_inputs = example.prompt_inputs.as_ref()?;
70
71 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
72
73 // Format cursor excerpt (reuse from format_prompt)
74 let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
75
76 let mut edit_history = String::new();
77 for event in &prompt_inputs.edit_history {
78 match event.as_ref() {
79 zeta_prompt::Event::BufferChange {
80 path,
81 old_path,
82 diff,
83 predicted: _,
84 in_open_source_repo: _,
85 } => {
86 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
87 edit_history.push_str(&format!("+++ b{}\n", path.display()));
88 let diff_word_diff = unified_to_word_diff(diff);
89 edit_history.push_str(&diff_word_diff);
90 edit_history.push_str("\n\n");
91 }
92 }
93 }
94
95 let prompt_template = crate::prompt_assets::get_prompt("qa.md");
96 Some(
97 prompt_template
98 .replace("{edit_history}", &edit_history)
99 .replace("{cursor_excerpt}", &cursor_excerpt)
100 .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
101 )
102}
103
/// Extract the contents of the first fenced code block (```…```) from a
/// response.
///
/// The opening fence line (including any language tag) is discarded. If the
/// fence is never closed, everything after it is returned; if no fence is
/// present at all, returns `None`.
fn extract_codeblock(response: &str) -> Option<String> {
    let mut lines = response.lines();
    // Advance past the opening fence; bail out if there is none.
    lines.by_ref().find(|line| line.starts_with("```"))?;

    let mut body: Vec<&str> = Vec::new();
    for line in lines {
        if line.starts_with("```") {
            // Closing fence: stop collecting.
            break;
        }
        body.push(line);
    }
    Some(body.join("\n"))
}
120
121/// Parse the LLM response into a QaResult.
122fn parse_response(response_text: &str) -> QaResult {
123 let codeblock = extract_codeblock(response_text);
124
125 // Try parsing codeblock first, then fall back to raw response
126 for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
127 let Some(text) = text_to_parse else {
128 continue;
129 };
130
131 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
132 return QaResult {
133 reasoning: parsed
134 .get("reasoning")
135 .and_then(|v| v.as_str())
136 .map(|s| s.to_string()),
137 reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
138 confidence: parsed
139 .get("confidence")
140 .and_then(|v| v.as_u64())
141 .map(|v| v as u8),
142 response: Some(response_text.to_string()),
143 error: None,
144 };
145 }
146 }
147
148 // If all parsing attempts fail, return error
149 QaResult {
150 reasoning: Some(response_text.to_string()),
151 reverts_edits: None,
152 confidence: None,
153 response: Some(response_text.to_string()),
154 error: Some("Could not parse JSON from response".to_string()),
155 }
156}
157
/// Thin wrapper that dispatches requests to whichever provider client the
/// user selected via `--backend`.
enum QaClient {
    Anthropic(AnthropicClient),
    OpenAi(OpenAiClient),
}
162
impl QaClient {
    /// Send `prompt` as a single user message to the active provider and
    /// return the response's text content, with all text parts concatenated.
    ///
    /// Returns `Ok(None)` when the client yields no response — presumably
    /// when the request was only queued for a later batch; verify against
    /// the client implementations.
    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
        match self {
            QaClient::Anthropic(client) => {
                let messages = vec![anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: prompt.to_string(),
                        cache_control: None,
                    }],
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Join every text block in the response; non-text content is
                // silently dropped.
                Ok(response.map(|r| {
                    r.content
                        .iter()
                        .filter_map(|c| match c {
                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
            QaClient::OpenAi(client) => {
                let messages = vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt.to_string()),
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Concatenate assistant text across all choices: plain
                // content is used as-is, multipart content keeps only its
                // text parts; non-assistant messages are skipped.
                Ok(response.map(|r| {
                    r.choices
                        .into_iter()
                        .filter_map(|choice| match choice.message {
                            open_ai::RequestMessage::Assistant { content, .. } => {
                                content.map(|c| match c {
                                    open_ai::MessageContent::Plain(text) => text,
                                    open_ai::MessageContent::Multipart(parts) => parts
                                        .into_iter()
                                        .filter_map(|p| match p {
                                            open_ai::MessagePart::Text { text } => Some(text),
                                            _ => None,
                                        })
                                        .collect::<Vec<_>>()
                                        .join(""),
                                })
                            }
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
        }
    }

    /// Upload pending batch requests and download finished batch results for
    /// whichever provider client is active.
    async fn sync_batches(&self) -> Result<()> {
        match self {
            QaClient::Anthropic(client) => client.sync_batches().await,
            QaClient::OpenAi(client) => client.sync_batches().await,
        }
    }
}
228
229/// Run the QA evaluation on a set of examples.
230pub async fn run_qa(
231 examples: &mut [Example],
232 args: &QaArgs,
233 output_path: Option<&PathBuf>,
234) -> Result<()> {
235 let model = model_for_backend(args.backend);
236 let client = match args.backend {
237 BatchProvider::Anthropic => {
238 if args.no_batch {
239 QaClient::Anthropic(AnthropicClient::plain()?)
240 } else {
241 QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
242 }
243 }
244 BatchProvider::Openai => {
245 if args.no_batch {
246 QaClient::OpenAi(OpenAiClient::plain()?)
247 } else {
248 QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
249 }
250 }
251 };
252
253 eprintln!(
254 "Using model: {}, backend: {:?}, batching: {}",
255 model, args.backend, !args.no_batch
256 );
257
258 // First pass: send requests (client handles caching internally)
259 let mut prompts: Vec<(usize, String)> = Vec::new();
260 let mut skipped_count = 0;
261
262 for (idx, example) in examples.iter().enumerate() {
263 let Some(prompt) = build_prompt(example) else {
264 skipped_count += 1;
265 continue;
266 };
267 prompts.push((idx, prompt));
268 }
269
270 if skipped_count > 0 {
271 eprintln!("Skipping {} items with missing actual_patch", skipped_count);
272 }
273
274 eprintln!("{} items to process", prompts.len());
275
276 // Process all items
277 let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();
278
279 if args.no_batch {
280 // Synchronous processing
281 for (i, (idx, prompt)) in prompts.iter().enumerate() {
282 eprint!("\rProcessing {}/{}", i + 1, prompts.len());
283
284 let response = client.generate(model, 1024, prompt).await?;
285 let result = response.map(|text| parse_response(&text));
286 results.push((*idx, result));
287 }
288 eprintln!();
289 } else {
290 // Queue all for batching
291 for (idx, prompt) in &prompts {
292 let response = client.generate(model, 1024, prompt).await?;
293 let result = response.map(|text| parse_response(&text));
294 results.push((*idx, result));
295 }
296
297 // Sync batches (upload pending, download finished)
298 client.sync_batches().await?;
299
300 if args.wait {
301 eprintln!("Waiting for batch to complete...");
302 loop {
303 std::thread::sleep(std::time::Duration::from_secs(30));
304 client.sync_batches().await?;
305
306 // Re-check all items that didn't have results
307 let mut all_done = true;
308 for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
309 if results[result_idx].1.is_none() {
310 let response = client.generate(model, 1024, prompt).await?;
311 if let Some(text) = response {
312 results[result_idx] = (*idx, Some(parse_response(&text)));
313 } else {
314 all_done = false;
315 }
316 }
317 }
318
319 let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
320 if all_done {
321 break;
322 }
323 eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
324 }
325 } else {
326 let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
327 if pending_count > 0 {
328 eprintln!(
329 "Batch submitted. {} pending. Run again later to retrieve results.",
330 pending_count
331 );
332 }
333 }
334 }
335
336 // Build results map by index
337 let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
338 std::collections::HashMap::new();
339 for (idx, result) in results {
340 if let Some(r) = result {
341 results_by_idx.insert(idx, r);
342 }
343 }
344
345 // Output results
346 let mut writer: Box<dyn Write> = if let Some(path) = output_path {
347 Box::new(BufWriter::new(std::fs::File::create(path)?))
348 } else {
349 Box::new(std::io::stdout())
350 };
351
352 let mut num_total = 0;
353 let mut num_reverts_edits = 0;
354
355 for (idx, example) in examples.iter_mut().enumerate() {
356 // Skip examples that couldn't be processed
357 if build_prompt(example).is_none() {
358 continue;
359 }
360
361 let result = results_by_idx.get(&idx).cloned();
362
363 if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
364 num_reverts_edits += 1;
365 }
366 num_total += 1;
367
368 // Populate QA results for each prediction (currently only first prediction is evaluated)
369 example.qa = example
370 .predictions
371 .iter()
372 .enumerate()
373 .map(|(i, _)| if i == 0 { result.clone() } else { None })
374 .collect();
375
376 writeln!(writer, "{}", serde_json::to_string(&example)?)?;
377 }
378
379 if let Some(path) = output_path {
380 eprintln!("Results written to {}", path.display());
381 }
382
383 eprintln!("Processed: {} items", num_total);
384 if num_total > 0 {
385 eprintln!(
386 "Reverts edits: {} ({:.2}%)",
387 num_reverts_edits,
388 num_reverts_edits as f64 / num_total as f64 * 100.0
389 );
390 }
391
392 Ok(())
393}