1//! Quality assessment of predictions using LLM-as-a-judge.
2//!
3//! This module uses LLM Batch APIs to evaluate prediction quality.
4//! Caching is handled by the underlying client.
5
6use crate::BatchProvider;
7use crate::example::{Example, ExamplePrediction};
8use crate::format_prompt::extract_cursor_excerpt_from_example;
9use crate::llm_client::{LlmClient, model_for_backend};
10use crate::word_diff::unified_to_word_diff;
11use anyhow::Result;
12use serde::{Deserialize, Serialize};
13use std::io::{BufWriter, Write};
14use std::path::PathBuf;
15
16const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
17
/// Arguments for the QA command.
// NOTE: the `///` doc comments on the struct and its fields are consumed by
// clap's derive macro and become the CLI `--help` text — treat them as
// user-facing strings, not just documentation.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    // When set, each prompt is sent as an individual request instead of
    // being queued for the provider's batch API (see `run_qa`).
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    // Only meaningful in batch mode; ignored when `no_batch` is set.
    #[clap(long)]
    pub wait: bool,

    /// Which LLM provider to use (anthropic or openai)
    // Parsed into `BatchProvider` by clap; defaults to OpenAI.
    #[clap(long, default_value = "openai")]
    pub backend: BatchProvider,
}
33
/// Result of QA evaluation for a single prediction.
///
/// All fields are optional: any subset may be missing when the judge's
/// response could not be fully parsed. On a parse failure, `error` is set
/// and `reasoning`/`response` hold the raw model output (see
/// `parse_response`). `skip_serializing_if` keeps absent fields out of the
/// serialized JSON entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    // Parsed from a JSON u64 and truncated with `as u8` in `parse_response`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
57
58/// Build the assessment prompt for an example (uses first prediction).
59pub fn build_prompt(example: &Example) -> Option<String> {
60 let prediction = example.predictions.first()?;
61 build_prompt_for_prediction(example, prediction)
62}
63
64/// Build the assessment prompt for a specific prediction.
65pub fn build_prompt_for_prediction(
66 example: &Example,
67 prediction: &ExamplePrediction,
68) -> Option<String> {
69 let actual_patch = prediction.actual_patch.as_ref()?;
70 let prompt_inputs = example.prompt_inputs.as_ref()?;
71
72 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
73
74 // Format cursor excerpt (reuse from format_prompt)
75 let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
76
77 let mut edit_history = String::new();
78 for event in &prompt_inputs.edit_history {
79 match event.as_ref() {
80 zeta_prompt::Event::BufferChange {
81 path,
82 old_path,
83 diff,
84 predicted: _,
85 in_open_source_repo: _,
86 } => {
87 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
88 edit_history.push_str(&format!("+++ b{}\n", path.display()));
89 let diff_word_diff = unified_to_word_diff(diff);
90 edit_history.push_str(&diff_word_diff);
91 edit_history.push_str("\n\n");
92 }
93 }
94 }
95
96 Some(
97 PROMPT_TEMPLATE
98 .replace("{edit_history}", &edit_history)
99 .replace("{cursor_excerpt}", &cursor_excerpt)
100 .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
101 )
102}
103
/// Extract a code block from a response.
///
/// Returns the text between the first pair of ``` fences. If the opening
/// fence is never closed, everything after it is returned. Returns `None`
/// when the response contains no fence at all.
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();

    // The block body starts on the line after the first fence.
    let body_start = lines.iter().position(|line| line.starts_with("```"))? + 1;
    let body = &lines[body_start..];

    // Stop at the closing fence, or take everything if it is missing.
    let body_end = body
        .iter()
        .position(|line| line.starts_with("```"))
        .unwrap_or(body.len());

    Some(body[..body_end].join("\n"))
}
120
121/// Parse the LLM response into a QaResult.
122pub(crate) fn parse_response(response_text: &str) -> QaResult {
123 let codeblock = extract_codeblock(response_text);
124
125 // Try parsing codeblock first, then fall back to raw response
126 for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
127 let Some(text) = text_to_parse else {
128 continue;
129 };
130
131 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
132 return QaResult {
133 reasoning: parsed
134 .get("reasoning")
135 .and_then(|v| v.as_str())
136 .map(|s| s.to_string()),
137 reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
138 confidence: parsed
139 .get("confidence")
140 .and_then(|v| v.as_u64())
141 .map(|v| v as u8),
142 response: Some(response_text.to_string()),
143 error: None,
144 };
145 }
146 }
147
148 // If all parsing attempts fail, return error
149 QaResult {
150 reasoning: Some(response_text.to_string()),
151 reverts_edits: None,
152 confidence: None,
153 response: Some(response_text.to_string()),
154 error: Some("Could not parse JSON from response".to_string()),
155 }
156}
157
/// Run the QA evaluation on a set of examples.
///
/// Builds one judge prompt per example (first prediction only, via
/// `build_prompt`), sends the prompts through `LlmClient` — synchronously
/// when `args.no_batch` is set, otherwise via the provider's batch API —
/// parses responses into `QaResult`s, and writes each processed example as
/// one JSON line to `output_path` (or stdout). Finally prints summary
/// statistics to stderr.
///
/// In batch mode, `client.generate` presumably queues the request and
/// returns `Ok(None)` until batch results are available — TODO confirm
/// against the `LlmClient` implementation.
///
/// # Errors
/// Returns an error if the client cannot be constructed, a request fails,
/// the output file cannot be created, or JSON serialization fails.
pub async fn run_qa(
    examples: &mut [Example],
    args: &QaArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let model = model_for_backend(args.backend);
    let client = LlmClient::new(args.backend, !args.no_batch)?;

    eprintln!(
        "Using model: {}, backend: {:?}, batching: {}",
        model, args.backend, !args.no_batch
    );

    // First pass: send requests (client handles caching internally)
    let mut prompts: Vec<(usize, String)> = Vec::new();
    let mut skipped_count = 0;

    // Pair each prompt with its example index so results can be mapped back
    // even though some examples are skipped.
    for (idx, example) in examples.iter().enumerate() {
        let Some(prompt) = build_prompt(example) else {
            skipped_count += 1;
            continue;
        };
        prompts.push((idx, prompt));
    }

    if skipped_count > 0 {
        eprintln!("Skipping {} items with missing actual_patch", skipped_count);
    }

    eprintln!("{} items to process", prompts.len());

    // Process all items. Each entry is (example index, result-if-available);
    // `None` means the batch has not produced a response yet.
    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in prompts.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, prompts.len());

            let response = client.generate(model, 1024, prompt).await?;
            let result = response.map(|text| parse_response(&text));
            results.push((*idx, result));
        }
        eprintln!();
    } else {
        // Queue all for batching
        for (idx, prompt) in &prompts {
            let response = client.generate(model, 1024, prompt).await?;
            let result = response.map(|text| parse_response(&text));
            results.push((*idx, result));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
                // NOTE(review): std::thread::sleep blocks the executor thread
                // inside an async fn; presumably acceptable for this
                // single-task CLI loop — consider an async sleep if this ever
                // runs on a shared runtime.
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        // Re-issuing the same prompt; the client is expected
                        // to serve it from the now-downloaded batch results.
                        let response = client.generate(model, 1024, prompt).await?;
                        if let Some(text) = response {
                            results[result_idx] = (*idx, Some(parse_response(&text)));
                        } else {
                            all_done = false;
                        }
                    }
                }

                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                if all_done {
                    break;
                }
                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
            }
        } else {
            // Fire-and-forget: report how many responses are still pending so
            // the user knows to re-run later.
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_total = 0;
    let mut num_reverts_edits = 0;

    for (idx, example) in examples.iter_mut().enumerate() {
        // Skip examples that couldn't be processed
        // (mirrors the skip logic of the first pass above).
        if build_prompt(example).is_none() {
            continue;
        }

        let result = results_by_idx.get(&idx).cloned();

        if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
            num_reverts_edits += 1;
        }
        // Counted even when the result is still pending (`result` is None);
        // such examples are written with an empty qa slot.
        num_total += 1;

        // Populate QA results for each prediction (currently only first prediction is evaluated)
        example.qa = example
            .predictions
            .iter()
            .enumerate()
            .map(|(i, _)| if i == 0 { result.clone() } else { None })
            .collect();

        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
    }

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Processed: {} items", num_total);
    if num_total > 0 {
        eprintln!(
            "Reverts edits: {} ({:.2}%)",
            num_reverts_edits,
            num_reverts_edits as f64 / num_total as f64 * 100.0
        );
    }

    Ok(())
}