//! Quality assessment of predictions using LLM-as-a-judge.
//!
//! This module uses the Anthropic Batch API to evaluate prediction quality.
//! Caching is handled by the underlying `AnthropicClient`.
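//!
//! A sketch of the intended flow (illustrative; `load_examples` is a
//! hypothetical stand-in for however `Example`s are loaded):
//!
//! ```ignore
//! let mut examples = load_examples("predictions.jsonl")?;
//! let args = QaArgs { no_batch: false, wait: true };
//! run_qa(&mut examples, &args, Some(&PathBuf::from("qa_results.jsonl"))).await?;
//! ```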

use crate::anthropic_client::AnthropicClient;
use crate::example::Example;
use crate::paths::CACHE_DIR;
use crate::word_diff::unified_to_word_diff;
use anthropic::{Message, RequestContent, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::io::{BufWriter, Write};
use std::path::PathBuf;
use std::sync::LazyLock;

/// Model to use for QA evaluation.
const MODEL: &str = "claude-sonnet-4-5";

/// Path to the QA cache database.
pub static QA_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("qa_cache.sqlite"));

/// Arguments for the QA command.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,
}

/// Result of QA evaluation for a single prediction.
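///
/// All fields are optional so partial judge output still deserializes.
/// For example (illustrative):
///
/// ```ignore
/// let result: QaResult =
///     serde_json::from_str(r#"{"reverts_edits": false, "confidence": 4}"#)?;
/// assert_eq!(result.confidence, Some(4));
/// ```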
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

/// Build the assessment prompt for an example.
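///
/// Returns `None` when the example lacks a prediction, an `actual_patch`,
/// or `prompt_inputs`. Illustrative use (assumes a populated `Example`):
///
/// ```ignore
/// if let Some(prompt) = build_prompt(&example) {
///     assert!(prompt.contains("## Predicted Next Edit"));
/// }
/// ```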
pub fn build_prompt(example: &Example) -> Option<String> {
    let prediction = example.predictions.first()?;
    let actual_patch = prediction.actual_patch.as_ref()?;
    let prompt_inputs = example.prompt_inputs.as_ref()?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    let mut edit_history = String::new();
    for event in &prompt_inputs.edit_history {
        match event.as_ref() {
            zeta_prompt::Event::BufferChange {
                path,
                old_path,
                diff,
                predicted: _,
                in_open_source_repo: _,
            } => {
                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
                edit_history.push_str(&format!("+++ b{}\n", path.display()));
                edit_history.push_str(&unified_to_word_diff(diff));
                edit_history.push_str("\n\n");
            }
        }
    }

    Some(format!(
        r#"
You are evaluating an edit prediction model for a code editor. The model observes a programmer's recent edit history and predicts what edit they will make next.

All diffs are in the word-diff format.

The model is instructed to:
- Complete partially-applied refactoring or changes
- Maintain consistency with established patterns and style
- NOT delete or revert text that was just added (unless the user explicitly undid it themselves)

## Edit History (chronological)
```````
{edit_history}
```````

## Predicted Next Edit
```````
{actual_patch_word_diff}
```````

## Evaluate
1. **reverts_edits**: Does the prediction undo or revert changes the user intentionally made in the **edit history**?

2. **confidence**: How likely is the user to accept this suggestion?
   - 1 = Definitely reject (wrong, nonsensical, or harmful)
   - 2 = Probably reject (doesn't fit intent or pattern)
   - 3 = Uncertain (plausible but not clearly correct)
   - 4 = Probably accept (reasonable next step)
   - 5 = Definitely accept (obvious continuation)

Output JSON in this format:

```
{{
  "reasoning": "your reasoning here",
  "reverts_edits": true/false,
  "confidence": 1-5
}}
```
"#
    ))
}

/// Extract the contents of the first fenced code block in a response.
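///
/// If the opening fence is never closed, everything after it is returned.
/// For example (illustrative):
///
/// ```ignore
/// let inner = extract_codeblock("prefix\n```json\n{\"ok\": true}\n```");
/// assert_eq!(inner.as_deref(), Some("{\"ok\": true}"));
/// ```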
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();
    for (i, line) in lines.iter().enumerate() {
        if line.starts_with("```") {
            let start = i + 1;
            for (j, end_line) in lines[start..].iter().enumerate() {
                if end_line.starts_with("```") {
                    return Some(lines[start..start + j].join("\n"));
                }
            }
            return Some(lines[start..].join("\n"));
        }
    }
    None
}

/// Parse the LLM response into a [`QaResult`].
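///
/// Parsing is attempted on the first fenced code block, then on the raw
/// response; if both fail, the result carries an `error`. For example
/// (illustrative):
///
/// ```ignore
/// let result = parse_response("not json at all");
/// assert!(result.error.is_some());
/// assert_eq!(result.confidence, None);
/// ```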
fn parse_response(response_text: &str) -> QaResult {
    let codeblock = extract_codeblock(response_text);

    // Try parsing codeblock first, then fall back to raw response
    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
        let Some(text) = text_to_parse else {
            continue;
        };

        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
            return QaResult {
                reasoning: parsed
                    .get("reasoning")
                    .and_then(|v| v.as_str())
                    .map(|s| s.to_string()),
                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
                confidence: parsed
                    .get("confidence")
                    .and_then(|v| v.as_u64())
                    .map(|v| v as u8),
                response: Some(response_text.to_string()),
                error: None,
            };
        }
    }

    // If all parsing attempts fail, return error
    QaResult {
        reasoning: Some(response_text.to_string()),
        reverts_edits: None,
        confidence: None,
        response: Some(response_text.to_string()),
        error: Some("Could not parse JSON from response".to_string()),
    }
}

/// Wrap a prompt in a single-turn user message.
fn user_message(prompt: &str) -> Vec<Message> {
    vec![Message {
        role: Role::User,
        content: vec![RequestContent::Text {
            text: prompt.to_string(),
            cache_control: None,
        }],
    }]
}

/// Concatenate the text blocks of a model response into a single string.
fn response_text(content: &[anthropic::ResponseContent]) -> String {
    content
        .iter()
        .filter_map(|c| match c {
            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join("")
}

/// Run the QA evaluation on a set of examples.
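///
/// In batch mode, the first run submits requests and a later run (or
/// `--wait`) retrieves the results. A sketch of a synchronous invocation
/// (illustrative):
///
/// ```ignore
/// let args = QaArgs { no_batch: true, wait: false };
/// run_qa(&mut examples, &args, None).await?; // results stream to stdout
/// ```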
pub async fn run_qa(
    examples: &mut [Example],
    args: &QaArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let client = if args.no_batch {
        AnthropicClient::plain()?
    } else {
        AnthropicClient::batch(&QA_CACHE_DB)?
    };

    eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);

    // First pass: build prompts (the client handles request caching internally)
    let mut prompts: Vec<(usize, String)> = Vec::new();
    let mut skipped_count = 0;

    for (idx, example) in examples.iter().enumerate() {
        let Some(prompt) = build_prompt(example) else {
            skipped_count += 1;
            continue;
        };
        prompts.push((idx, prompt));
    }

    if skipped_count > 0 {
        eprintln!("Skipping {} items with missing actual_patch", skipped_count);
    }

    eprintln!("{} items to process", prompts.len());

    // Process all items
    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in prompts.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, prompts.len());

            let response = client.generate(MODEL, 1024, user_message(prompt)).await?;
            let result = response.map(|r| parse_response(&response_text(&r.content)));
            results.push((*idx, result));
        }
        eprintln!();
    } else {
        // Queue all for batching
        for (idx, prompt) in &prompts {
            let response = client.generate(MODEL, 1024, user_message(prompt)).await?;
            let result = response.map(|r| parse_response(&response_text(&r.content)));
            results.push((*idx, result));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results yet
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        let response =
                            client.generate(MODEL, 1024, user_message(prompt)).await?;
                        match response {
                            Some(r) => {
                                results[result_idx] =
                                    (*idx, Some(parse_response(&response_text(&r.content))));
                            }
                            None => all_done = false,
                        }
                    }
                }

                if all_done {
                    break;
                }
                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
            }
        } else {
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_total = 0;
    let mut num_reverts_edits = 0;

    for (idx, example) in examples.iter().enumerate() {
        // Skip examples that couldn't be processed
        if build_prompt(example).is_none() {
            continue;
        }

        let result = results_by_idx
            .get(&idx)
            .cloned()
            .unwrap_or_else(|| QaResult {
                reasoning: None,
                reverts_edits: None,
                confidence: None,
                response: None,
                error: Some("Result not found".to_string()),
            });

        if result.reverts_edits == Some(true) {
            num_reverts_edits += 1;
        }
        num_total += 1;

        // Attach the QA result to a JSON copy of the example and write it out
        let mut example_json = serde_json::to_value(example)?;
        example_json["qa"] = serde_json::to_value(&result)?;
        writeln!(writer, "{}", serde_json::to_string(&example_json)?)?;
    }

    writer.flush()?;

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Processed: {} items", num_total);
    if num_total > 0 {
        eprintln!(
            "Reverts edits: {} ({:.2}%)",
            num_reverts_edits,
            num_reverts_edits as f64 / num_total as f64 * 100.0
        );
    }

    Ok(())
}