//! Quality assessment of predictions using LLM-as-a-judge.
//!
//! This module uses the Anthropic Batch API to evaluate prediction quality.
//! Caching is handled by the underlying `AnthropicClient`.
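//!
//! A sketch of the intended flow (illustrative; `load_examples` is a
//! hypothetical stand-in for however `Example`s are loaded):
//!
//! ```ignore
//! let mut examples = load_examples("predictions.jsonl")?;
//! let args = QaArgs { no_batch: false, wait: true };
//! run_qa(&mut examples, &args, Some(&PathBuf::from("qa_results.jsonl"))).await?;
//! ```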

use crate::anthropic_client::AnthropicClient;
use crate::example::Example;
use crate::paths::CACHE_DIR;
use crate::word_diff::unified_to_word_diff;
use anthropic::{Message, RequestContent, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::io::{BufWriter, Write};
use std::path::PathBuf;
use std::sync::LazyLock;

/// Model to use for QA evaluation.
const MODEL: &str = "claude-sonnet-4-5";

/// Path to the QA cache database.
pub static QA_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("qa_cache.sqlite"));

/// Arguments for the QA command.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,
}

/// Result of QA evaluation for a single prediction.
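///
/// All fields are optional so partial judge output still deserializes.
/// For example (illustrative):
///
/// ```ignore
/// let result: QaResult =
///     serde_json::from_str(r#"{"reverts_edits": false, "confidence": 4}"#)?;
/// assert_eq!(result.confidence, Some(4));
/// ```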
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

/// Build the assessment prompt for an example.
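///
/// Returns `None` when the example lacks a prediction, an `actual_patch`,
/// or `prompt_inputs`. Illustrative use (assumes a populated `Example`):
///
/// ```ignore
/// if let Some(prompt) = build_prompt(&example) {
///     assert!(prompt.contains("## Predicted Next Edit"));
/// }
/// ```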
pub fn build_prompt(example: &Example) -> Option<String> {
    let prediction = example.predictions.first()?;
    let actual_patch = prediction.actual_patch.as_ref()?;
    let prompt_inputs = example.prompt_inputs.as_ref()?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    let mut edit_history = String::new();
    for event in &prompt_inputs.edit_history {
        match event.as_ref() {
            zeta_prompt::Event::BufferChange {
                path,
                old_path,
                diff,
                predicted: _,
                in_open_source_repo: _,
            } => {
                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
                edit_history.push_str(&format!("+++ b{}\n", path.display()));
                edit_history.push_str(&unified_to_word_diff(diff));
                edit_history.push_str("\n\n");
            }
        }
    }

    Some(format!(
        r#"
You are evaluating an edit prediction model for a code editor. The model observes a programmer's recent edit history and predicts what edit they will make next.

All diffs are in the word-diff format.

The model is instructed to:
- Complete partially-applied refactoring or changes
- Maintain consistency with established patterns and style
- NOT delete or revert text that was just added (unless the user explicitly undid it themselves)

## Edit History (chronological)
```````
{edit_history}
```````

## Predicted Next Edit
```````
{actual_patch_word_diff}
```````

## Evaluate
1. **reverts_edits**: Does the prediction undo or revert changes the user intentionally made in the **edit history**?

2. **confidence**: How likely is the user to accept this suggestion?
   - 1 = Definitely reject (wrong, nonsensical, or harmful)
   - 2 = Probably reject (doesn't fit intent or pattern)
   - 3 = Uncertain (plausible but not clearly correct)
   - 4 = Probably accept (reasonable next step)
   - 5 = Definitely accept (obvious continuation)

Output JSON in this format:

```
{{
  "reasoning": "your reasoning here",
  "reverts_edits": true/false,
  "confidence": 1-5
}}
```
"#
    ))
}

/// Extract the contents of the first fenced code block in a response.
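///
/// If the opening fence is never closed, everything after it is returned.
/// For example (illustrative):
///
/// ```ignore
/// let inner = extract_codeblock("prefix\n```json\n{\"ok\": true}\n```");
/// assert_eq!(inner.as_deref(), Some("{\"ok\": true}"));
/// ```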
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();
    for (i, line) in lines.iter().enumerate() {
        if line.starts_with("```") {
            let start = i + 1;
            for (j, end_line) in lines[start..].iter().enumerate() {
                if end_line.starts_with("```") {
                    return Some(lines[start..start + j].join("\n"));
                }
            }
            return Some(lines[start..].join("\n"));
        }
    }
    None
}

/// Parse the LLM response into a [`QaResult`].
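///
/// Parsing is attempted on the first fenced code block, then on the raw
/// response; if both fail, the result carries an `error`. For example
/// (illustrative):
///
/// ```ignore
/// let result = parse_response("not json at all");
/// assert!(result.error.is_some());
/// assert_eq!(result.confidence, None);
/// ```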
fn parse_response(response_text: &str) -> QaResult {
    let codeblock = extract_codeblock(response_text);

    // Try parsing codeblock first, then fall back to raw response
    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
        let Some(text) = text_to_parse else {
            continue;
        };

        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
            return QaResult {
                reasoning: parsed
                    .get("reasoning")
                    .and_then(|v| v.as_str())
                    .map(|s| s.to_string()),
                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
                confidence: parsed
                    .get("confidence")
                    .and_then(|v| v.as_u64())
                    .map(|v| v as u8),
                response: Some(response_text.to_string()),
                error: None,
            };
        }
    }

    // If all parsing attempts fail, return error
    QaResult {
        reasoning: Some(response_text.to_string()),
        reverts_edits: None,
        confidence: None,
        response: Some(response_text.to_string()),
        error: Some("Could not parse JSON from response".to_string()),
    }
}

/// Wrap a prompt in a single-turn user message.
fn user_message(prompt: &str) -> Vec<Message> {
    vec![Message {
        role: Role::User,
        content: vec![RequestContent::Text {
            text: prompt.to_string(),
            cache_control: None,
        }],
    }]
}

/// Concatenate the text blocks of a model response into a single string.
fn response_text(content: &[anthropic::ResponseContent]) -> String {
    content
        .iter()
        .filter_map(|c| match c {
            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join("")
}

/// Run the QA evaluation on a set of examples.
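///
/// In batch mode, the first run submits requests and a later run (or
/// `--wait`) retrieves the results. A sketch of a synchronous invocation
/// (illustrative):
///
/// ```ignore
/// let args = QaArgs { no_batch: true, wait: false };
/// run_qa(&mut examples, &args, None).await?; // results stream to stdout
/// ```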
pub async fn run_qa(
    examples: &mut [Example],
    args: &QaArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let client = if args.no_batch {
        AnthropicClient::plain()?
    } else {
        AnthropicClient::batch(&QA_CACHE_DB)?
    };

    eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);

    // First pass: build prompts (the client handles request caching internally)
    let mut prompts: Vec<(usize, String)> = Vec::new();
    let mut skipped_count = 0;

    for (idx, example) in examples.iter().enumerate() {
        let Some(prompt) = build_prompt(example) else {
            skipped_count += 1;
            continue;
        };
        prompts.push((idx, prompt));
    }

    if skipped_count > 0 {
        eprintln!("Skipping {} items with missing actual_patch", skipped_count);
    }

    eprintln!("{} items to process", prompts.len());

    // Process all items
    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in prompts.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, prompts.len());

            let response = client.generate(MODEL, 1024, user_message(prompt)).await?;
            let result = response.map(|r| parse_response(&response_text(&r.content)));
            results.push((*idx, result));
        }
        eprintln!();
    } else {
        // Queue all for batching
        for (idx, prompt) in &prompts {
            let response = client.generate(MODEL, 1024, user_message(prompt)).await?;
            let result = response.map(|r| parse_response(&response_text(&r.content)));
            results.push((*idx, result));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results yet
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        let response =
                            client.generate(MODEL, 1024, user_message(prompt)).await?;
                        match response {
                            Some(r) => {
                                results[result_idx] =
                                    (*idx, Some(parse_response(&response_text(&r.content))));
                            }
                            None => all_done = false,
                        }
                    }
                }

                if all_done {
                    break;
                }
                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
            }
        } else {
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_total = 0;
    let mut num_reverts_edits = 0;

    for (idx, example) in examples.iter().enumerate() {
        // Skip examples that couldn't be processed
        if build_prompt(example).is_none() {
            continue;
        }

        let result = results_by_idx
            .get(&idx)
            .cloned()
            .unwrap_or_else(|| QaResult {
                reasoning: None,
                reverts_edits: None,
                confidence: None,
                response: None,
                error: Some("Result not found".to_string()),
            });

        if result.reverts_edits == Some(true) {
            num_reverts_edits += 1;
        }
        num_total += 1;

        // Attach the QA result to a JSON copy of the example and write it out
        let mut example_json = serde_json::to_value(example)?;
        example_json["qa"] = serde_json::to_value(&result)?;
        writeln!(writer, "{}", serde_json::to_string(&example_json)?)?;
    }

    writer.flush()?;

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Processed: {} items", num_total);
    if num_total > 0 {
        eprintln!(
            "Reverts edits: {} ({:.2}%)",
            num_reverts_edits,
            num_reverts_edits as f64 / num_total as f64 * 100.0
        );
    }

    Ok(())
}