1//! Quality assessment of predictions using LLM-as-a-judge.
2//!
3//! This module uses LLM Batch APIs to evaluate prediction quality.
4//! Caching is handled by the underlying client.
5
6use crate::BatchProvider;
7use crate::anthropic_client::AnthropicClient;
8use crate::example::Example;
9use crate::format_prompt::extract_cursor_excerpt_from_example;
10use crate::openai_client::OpenAiClient;
11use crate::paths::LLM_CACHE_DB;
12use crate::word_diff::unified_to_word_diff;
13use anyhow::Result;
14use serde::{Deserialize, Serialize};
15use std::io::{BufWriter, Write};
16use std::path::PathBuf;
17
/// Judge prompt template with `{edit_history}`, `{cursor_excerpt}`, and
/// `{actual_patch_word_diff}` placeholders, filled in by `build_prompt`.
const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
19
20/// Arguments for the QA command.
21#[derive(Debug, Clone, clap::Args)]
22pub struct QaArgs {
23 /// Use synchronous API instead of batch
24 #[clap(long)]
25 pub no_batch: bool,
26
27 /// Wait for batch to complete (polls every 30s)
28 #[clap(long)]
29 pub wait: bool,
30
31 /// Which LLM provider to use (anthropic or openai)
32 #[clap(long, default_value = "openai")]
33 pub backend: BatchProvider,
34}
35
36fn model_for_backend(backend: BatchProvider) -> &'static str {
37 match backend {
38 BatchProvider::Anthropic => "claude-sonnet-4-5",
39 BatchProvider::Openai => "gpt-5.2",
40 }
41}
42
43/// Result of QA evaluation for a single prediction.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct QaResult {
46 /// Free-form reasoning from the judge.
47 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub reasoning: Option<String>,
49
50 /// Does the prediction undo/revert changes the user intentionally made?
51 #[serde(default, skip_serializing_if = "Option::is_none")]
52 pub reverts_edits: Option<bool>,
53
54 /// Confidence score (1-5) for user acceptance likelihood.
55 #[serde(default, skip_serializing_if = "Option::is_none")]
56 pub confidence: Option<u8>,
57
58 /// The raw response from the model.
59 #[serde(default, skip_serializing_if = "Option::is_none")]
60 pub response: Option<String>,
61
62 /// Error message if parsing or request failed.
63 #[serde(default, skip_serializing_if = "Option::is_none")]
64 pub error: Option<String>,
65}
66
67/// Build the assessment prompt for an example.
68pub fn build_prompt(example: &Example) -> Option<String> {
69 let prediction = example.predictions.first()?;
70 let actual_patch = prediction.actual_patch.as_ref()?;
71 let prompt_inputs = example.prompt_inputs.as_ref()?;
72
73 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
74
75 // Format cursor excerpt (reuse from format_prompt)
76 let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
77
78 let mut edit_history = String::new();
79 for event in &prompt_inputs.edit_history {
80 match event.as_ref() {
81 zeta_prompt::Event::BufferChange {
82 path,
83 old_path,
84 diff,
85 predicted: _,
86 in_open_source_repo: _,
87 } => {
88 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
89 edit_history.push_str(&format!("+++ b{}\n", path.display()));
90 let diff_word_diff = unified_to_word_diff(diff);
91 edit_history.push_str(&diff_word_diff);
92 edit_history.push_str("\n\n");
93 }
94 }
95 }
96
97 Some(
98 PROMPT_TEMPLATE
99 .replace("{edit_history}", &edit_history)
100 .replace("{cursor_excerpt}", &cursor_excerpt)
101 .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
102 )
103}
104
105/// Extract a code block from a response.
106fn extract_codeblock(response: &str) -> Option<String> {
107 let lines: Vec<&str> = response.lines().collect();
108 for (i, line) in lines.iter().enumerate() {
109 if line.starts_with("```") {
110 let start = i + 1;
111 for (j, end_line) in lines[start..].iter().enumerate() {
112 if end_line.starts_with("```") {
113 return Some(lines[start..start + j].join("\n"));
114 }
115 }
116 return Some(lines[start..].join("\n"));
117 }
118 }
119 None
120}
121
122/// Parse the LLM response into a QaResult.
123fn parse_response(response_text: &str) -> QaResult {
124 let codeblock = extract_codeblock(response_text);
125
126 // Try parsing codeblock first, then fall back to raw response
127 for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
128 let Some(text) = text_to_parse else {
129 continue;
130 };
131
132 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
133 return QaResult {
134 reasoning: parsed
135 .get("reasoning")
136 .and_then(|v| v.as_str())
137 .map(|s| s.to_string()),
138 reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
139 confidence: parsed
140 .get("confidence")
141 .and_then(|v| v.as_u64())
142 .map(|v| v as u8),
143 response: Some(response_text.to_string()),
144 error: None,
145 };
146 }
147 }
148
149 // If all parsing attempts fail, return error
150 QaResult {
151 reasoning: Some(response_text.to_string()),
152 reverts_edits: None,
153 confidence: None,
154 response: Some(response_text.to_string()),
155 error: Some("Could not parse JSON from response".to_string()),
156 }
157}
158
/// Provider-agnostic wrapper over the two concrete LLM clients, so the
/// driver code can issue requests without matching on the backend everywhere.
enum QaClient {
    Anthropic(AnthropicClient),
    OpenAi(OpenAiClient),
}
163
impl QaClient {
    /// Send a single user prompt to the underlying client and return the
    /// response's text content, with all text parts concatenated.
    ///
    /// Returns `Ok(None)` when the client produced no response — in batch
    /// mode this appears to mean the request is queued and the result is not
    /// available yet (see how `run_qa` re-polls `None` entries); confirm
    /// against the client implementations.
    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
        match self {
            QaClient::Anthropic(client) => {
                let messages = vec![anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: prompt.to_string(),
                        cache_control: None,
                    }],
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Concatenate only the Text content items; other content
                // kinds (e.g. tool use) are ignored.
                Ok(response.map(|r| {
                    r.content
                        .iter()
                        .filter_map(|c| match c {
                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
            QaClient::OpenAi(client) => {
                let messages = vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt.to_string()),
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Flatten all assistant choices into a single string: plain
                // content is taken as-is, multipart content keeps only its
                // text parts. Non-assistant messages are dropped.
                Ok(response.map(|r| {
                    r.choices
                        .into_iter()
                        .filter_map(|choice| match choice.message {
                            open_ai::RequestMessage::Assistant { content, .. } => {
                                content.map(|c| match c {
                                    open_ai::MessageContent::Plain(text) => text,
                                    open_ai::MessageContent::Multipart(parts) => parts
                                        .into_iter()
                                        .filter_map(|p| match p {
                                            open_ai::MessagePart::Text { text } => Some(text),
                                            _ => None,
                                        })
                                        .collect::<Vec<_>>()
                                        .join(""),
                                })
                            }
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
        }
    }

    /// Synchronize batch state with the provider: upload pending requests
    /// and download finished results (delegates to the wrapped client).
    async fn sync_batches(&self) -> Result<()> {
        match self {
            QaClient::Anthropic(client) => client.sync_batches().await,
            QaClient::OpenAi(client) => client.sync_batches().await,
        }
    }
}
229
230/// Run the QA evaluation on a set of examples.
231pub async fn run_qa(
232 examples: &mut [Example],
233 args: &QaArgs,
234 output_path: Option<&PathBuf>,
235) -> Result<()> {
236 let model = model_for_backend(args.backend);
237 let client = match args.backend {
238 BatchProvider::Anthropic => {
239 if args.no_batch {
240 QaClient::Anthropic(AnthropicClient::plain()?)
241 } else {
242 QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
243 }
244 }
245 BatchProvider::Openai => {
246 if args.no_batch {
247 QaClient::OpenAi(OpenAiClient::plain()?)
248 } else {
249 QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
250 }
251 }
252 };
253
254 eprintln!(
255 "Using model: {}, backend: {:?}, batching: {}",
256 model, args.backend, !args.no_batch
257 );
258
259 // First pass: send requests (client handles caching internally)
260 let mut prompts: Vec<(usize, String)> = Vec::new();
261 let mut skipped_count = 0;
262
263 for (idx, example) in examples.iter().enumerate() {
264 let Some(prompt) = build_prompt(example) else {
265 skipped_count += 1;
266 continue;
267 };
268 prompts.push((idx, prompt));
269 }
270
271 if skipped_count > 0 {
272 eprintln!("Skipping {} items with missing actual_patch", skipped_count);
273 }
274
275 eprintln!("{} items to process", prompts.len());
276
277 // Process all items
278 let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();
279
280 if args.no_batch {
281 // Synchronous processing
282 for (i, (idx, prompt)) in prompts.iter().enumerate() {
283 eprint!("\rProcessing {}/{}", i + 1, prompts.len());
284
285 let response = client.generate(model, 1024, prompt).await?;
286 let result = response.map(|text| parse_response(&text));
287 results.push((*idx, result));
288 }
289 eprintln!();
290 } else {
291 // Queue all for batching
292 for (idx, prompt) in &prompts {
293 let response = client.generate(model, 1024, prompt).await?;
294 let result = response.map(|text| parse_response(&text));
295 results.push((*idx, result));
296 }
297
298 // Sync batches (upload pending, download finished)
299 client.sync_batches().await?;
300
301 if args.wait {
302 eprintln!("Waiting for batch to complete...");
303 loop {
304 std::thread::sleep(std::time::Duration::from_secs(30));
305 client.sync_batches().await?;
306
307 // Re-check all items that didn't have results
308 let mut all_done = true;
309 for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
310 if results[result_idx].1.is_none() {
311 let response = client.generate(model, 1024, prompt).await?;
312 if let Some(text) = response {
313 results[result_idx] = (*idx, Some(parse_response(&text)));
314 } else {
315 all_done = false;
316 }
317 }
318 }
319
320 let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
321 if all_done {
322 break;
323 }
324 eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
325 }
326 } else {
327 let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
328 if pending_count > 0 {
329 eprintln!(
330 "Batch submitted. {} pending. Run again later to retrieve results.",
331 pending_count
332 );
333 }
334 }
335 }
336
337 // Build results map by index
338 let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
339 std::collections::HashMap::new();
340 for (idx, result) in results {
341 if let Some(r) = result {
342 results_by_idx.insert(idx, r);
343 }
344 }
345
346 // Output results
347 let mut writer: Box<dyn Write> = if let Some(path) = output_path {
348 Box::new(BufWriter::new(std::fs::File::create(path)?))
349 } else {
350 Box::new(std::io::stdout())
351 };
352
353 let mut num_total = 0;
354 let mut num_reverts_edits = 0;
355
356 for (idx, example) in examples.iter_mut().enumerate() {
357 // Skip examples that couldn't be processed
358 if build_prompt(example).is_none() {
359 continue;
360 }
361
362 let result = results_by_idx.get(&idx).cloned();
363
364 if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
365 num_reverts_edits += 1;
366 }
367 num_total += 1;
368
369 // Populate QA results for each prediction (currently only first prediction is evaluated)
370 example.qa = example
371 .predictions
372 .iter()
373 .enumerate()
374 .map(|(i, _)| if i == 0 { result.clone() } else { None })
375 .collect();
376
377 writeln!(writer, "{}", serde_json::to_string(&example)?)?;
378 }
379
380 if let Some(path) = output_path {
381 eprintln!("Results written to {}", path.display());
382 }
383
384 eprintln!("Processed: {} items", num_total);
385 if num_total > 0 {
386 eprintln!(
387 "Reverts edits: {} ({:.2}%)",
388 num_reverts_edits,
389 num_reverts_edits as f64 / num_total as f64 * 100.0
390 );
391 }
392
393 Ok(())
394}