1//! Quality assessment of predictions using LLM-as-a-judge.
2//!
3//! This module uses LLM Batch APIs to evaluate prediction quality.
4//! Caching is handled by the underlying client.
5
6use crate::BatchProvider;
7use crate::example::{Example, ExamplePrediction};
8use crate::format_prompt::extract_cursor_excerpt_from_example;
9use crate::llm_client::{LlmClient, model_for_backend};
10use crate::word_diff::unified_to_word_diff;
11use anyhow::Result;
12use serde::{Deserialize, Serialize};
13use std::io::{BufWriter, Write};
14use std::path::PathBuf;
15
16const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
17
/// Arguments for the QA command.
// NOTE: the `///` doc comments on the struct and its fields are consumed by
// clap's derive macro and become the CLI `--help` text — treat them as
// user-facing strings, not just documentation.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    // When set, each prompt is sent as an individual request instead of
    // being queued for the provider's batch API (see `run_qa`).
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    // Only meaningful in batch mode; ignored when `no_batch` is set.
    #[clap(long)]
    pub wait: bool,

    /// Which LLM provider to use (anthropic or openai)
    // Parsed into `BatchProvider` by clap; defaults to OpenAI.
    #[clap(long, default_value = "openai")]
    pub backend: BatchProvider,
}
33
/// Result of QA evaluation for a single prediction.
///
/// All fields are optional: any subset may be missing when the judge's
/// response could not be fully parsed. On a parse failure, `error` is set
/// and `reasoning`/`response` hold the raw model output (see
/// `parse_response`). `skip_serializing_if` keeps absent fields out of the
/// serialized JSON entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    // Parsed from a JSON u64 and truncated with `as u8` in `parse_response`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
57
58/// Build the assessment prompt for an example (uses first prediction).
59pub fn build_prompt(example: &Example) -> Option<String> {
60 let prediction = example.predictions.first()?;
61 build_prompt_for_prediction(example, prediction)
62}
63
64/// Build the assessment prompt for a specific prediction.
65pub fn build_prompt_for_prediction(
66 example: &Example,
67 prediction: &ExamplePrediction,
68) -> Option<String> {
69 let actual_patch = prediction.actual_patch.as_ref()?;
70 let prompt_inputs = example.prompt_inputs.as_ref()?;
71
72 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
73
74 // Format cursor excerpt (reuse from format_prompt)
75 let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
76
77 let mut edit_history = String::new();
78 for event in &prompt_inputs.edit_history {
79 match event.as_ref() {
80 zeta_prompt::Event::BufferChange {
81 path,
82 old_path,
83 diff,
84 predicted: _,
85 in_open_source_repo: _,
86 } => {
87 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
88 edit_history.push_str(&format!("+++ b{}\n", path.display()));
89 let diff_word_diff = unified_to_word_diff(diff);
90 edit_history.push_str(&diff_word_diff);
91 edit_history.push_str("\n\n");
92 }
93 }
94 }
95
96 Some(
97 PROMPT_TEMPLATE
98 .replace("{edit_history}", &edit_history)
99 .replace("{cursor_excerpt}", &cursor_excerpt)
100 .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
101 )
102}
103
/// Extract a code block from a response.
///
/// Returns the text between the first pair of ``` fences. If the opening
/// fence is never closed, everything after it is returned. Returns `None`
/// when the response contains no fence at all.
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();

    // The block body starts on the line after the first fence.
    let body_start = lines.iter().position(|line| line.starts_with("```"))? + 1;
    let body = &lines[body_start..];

    // Stop at the closing fence, or take everything if it is missing.
    let body_end = body
        .iter()
        .position(|line| line.starts_with("```"))
        .unwrap_or(body.len());

    Some(body[..body_end].join("\n"))
}
120
121/// Parse the LLM response into a QaResult.
122pub(crate) fn parse_response(response_text: &str) -> QaResult {
123 let codeblock = extract_codeblock(response_text);
124
125 // Try parsing codeblock first, then fall back to raw response
126 for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
127 let Some(text) = text_to_parse else {
128 continue;
129 };
130
131 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
132 return QaResult {
133 reasoning: parsed
134 .get("reasoning")
135 .and_then(|v| v.as_str())
136 .map(|s| s.to_string()),
137 reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
138 confidence: parsed
139 .get("confidence")
140 .and_then(|v| v.as_u64())
141 .map(|v| v as u8),
142 response: Some(response_text.to_string()),
143 error: None,
144 };
145 }
146 }
147
148 // If all parsing attempts fail, return error
149 QaResult {
150 reasoning: Some(response_text.to_string()),
151 reverts_edits: None,
152 confidence: None,
153 response: Some(response_text.to_string()),
154 error: Some("Could not parse JSON from response".to_string()),
155 }
156}
157
/// Run the QA evaluation on a set of examples.
///
/// Builds one judge prompt per example (first prediction only, via
/// `build_prompt`), sends the prompts through `LlmClient` — synchronously
/// when `args.no_batch` is set, otherwise via the provider's batch API —
/// parses responses into `QaResult`s, and writes each processed example as
/// one JSON line to `output_path` (or stdout). Finally prints summary
/// statistics to stderr.
///
/// In batch mode, `client.generate` presumably queues the request and
/// returns `Ok(None)` until batch results are available — TODO confirm
/// against the `LlmClient` implementation.
///
/// # Errors
/// Returns an error if the client cannot be constructed, a request fails,
/// the output file cannot be created, or JSON serialization fails.
pub async fn run_qa(
    examples: &mut [Example],
    args: &QaArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let model = model_for_backend(args.backend);
    let client = LlmClient::new(args.backend, !args.no_batch)?;

    eprintln!(
        "Using model: {}, backend: {:?}, batching: {}",
        model, args.backend, !args.no_batch
    );

    // First pass: send requests (client handles caching internally)
    let mut prompts: Vec<(usize, String)> = Vec::new();
    let mut skipped_count = 0;

    // Pair each prompt with its example index so results can be mapped back
    // even though some examples are skipped.
    for (idx, example) in examples.iter().enumerate() {
        let Some(prompt) = build_prompt(example) else {
            skipped_count += 1;
            continue;
        };
        prompts.push((idx, prompt));
    }

    if skipped_count > 0 {
        eprintln!("Skipping {} items with missing actual_patch", skipped_count);
    }

    eprintln!("{} items to process", prompts.len());

    // Process all items. Each entry is (example index, result-if-available);
    // `None` means the batch has not produced a response yet.
    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in prompts.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, prompts.len());

            let response = client.generate(model, 1024, prompt).await?;
            let result = response.map(|text| parse_response(&text));
            results.push((*idx, result));
        }
        eprintln!();
    } else {
        // Queue all for batching
        for (idx, prompt) in &prompts {
            let response = client.generate(model, 1024, prompt).await?;
            let result = response.map(|text| parse_response(&text));
            results.push((*idx, result));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
                // NOTE(review): std::thread::sleep blocks the executor thread
                // inside an async fn; presumably acceptable for this
                // single-task CLI loop — consider an async sleep if this ever
                // runs on a shared runtime.
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        // Re-issuing the same prompt; the client is expected
                        // to serve it from the now-downloaded batch results.
                        let response = client.generate(model, 1024, prompt).await?;
                        if let Some(text) = response {
                            results[result_idx] = (*idx, Some(parse_response(&text)));
                        } else {
                            all_done = false;
                        }
                    }
                }

                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                if all_done {
                    break;
                }
                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
            }
        } else {
            // Fire-and-forget: report how many responses are still pending so
            // the user knows to re-run later.
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_total = 0;
    let mut num_reverts_edits = 0;

    for (idx, example) in examples.iter_mut().enumerate() {
        // Skip examples that couldn't be processed
        // (mirrors the skip logic of the first pass above).
        if build_prompt(example).is_none() {
            continue;
        }

        let result = results_by_idx.get(&idx).cloned();

        if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
            num_reverts_edits += 1;
        }
        // Counted even when the result is still pending (`result` is None);
        // such examples are written with an empty qa slot.
        num_total += 1;

        // Populate QA results for each prediction (currently only first prediction is evaluated)
        example.qa = example
            .predictions
            .iter()
            .enumerate()
            .map(|(i, _)| if i == 0 { result.clone() } else { None })
            .collect();

        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
    }

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Processed: {} items", num_total);
    if num_total > 0 {
        eprintln!(
            "Reverts edits: {} ({:.2}%)",
            num_reverts_edits,
            num_reverts_edits as f64 / num_total as f64 * 100.0
        );
    }

    Ok(())
}