1//! Quality assessment of predictions using LLM-as-a-judge.
2//!
3//! This module uses LLM Batch APIs to evaluate prediction quality.
4//! Caching is handled by the underlying client.
5
6use crate::BatchProvider;
7use crate::anthropic_client::AnthropicClient;
8use crate::example::Example;
9use crate::format_prompt::extract_cursor_excerpt_from_example;
10use crate::openai_client::OpenAiClient;
11use crate::paths::LLM_CACHE_DB;
12use crate::word_diff::unified_to_word_diff;
13use anyhow::Result;
14use serde::{Deserialize, Serialize};
15use std::io::{BufWriter, Write};
16use std::path::PathBuf;
17
/// Arguments for the QA command.
///
/// NOTE: the `///` doc comments on the fields below double as the CLI help
/// text rendered by clap — keep them user-facing.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "openai")]
    pub backend: BatchProvider,
}
33
34fn model_for_backend(backend: BatchProvider) -> &'static str {
35 match backend {
36 BatchProvider::Anthropic => "claude-sonnet-4-5",
37 BatchProvider::Openai => "gpt-5.2",
38 }
39}
40
/// Result of QA evaluation for a single prediction.
///
/// Every field is optional: the judge replies with free-form JSON, so any
/// key may be absent, and a reply that fails to parse populates only
/// `reasoning`/`response`/`error` (see `parse_response`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
64
65/// Build the assessment prompt for an example.
66pub fn build_prompt(example: &Example) -> Option<String> {
67 let prediction = example.predictions.first()?;
68 let actual_patch = prediction.actual_patch.as_ref()?;
69 let prompt_inputs = example.prompt_inputs.as_ref()?;
70
71 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
72
73 // Format cursor excerpt (reuse from format_prompt)
74 let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
75
76 let mut edit_history = String::new();
77 for event in &prompt_inputs.edit_history {
78 match event.as_ref() {
79 zeta_prompt::Event::BufferChange {
80 path,
81 old_path,
82 diff,
83 predicted: _,
84 in_open_source_repo: _,
85 } => {
86 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
87 edit_history.push_str(&format!("+++ b{}\n", path.display()));
88 let diff_word_diff = unified_to_word_diff(diff);
89 edit_history.push_str(&diff_word_diff);
90 edit_history.push_str("\n\n");
91 }
92 }
93 }
94
95 let prompt_template = crate::prompt_assets::get_prompt("qa.md");
96 Some(
97 prompt_template
98 .replace("{edit_history}", &edit_history)
99 .replace("{cursor_excerpt}", &cursor_excerpt)
100 .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
101 )
102}
103
/// Extract the contents of the first fenced code block (```…```) from a
/// response.
///
/// The opening fence line (including any language tag) is discarded. If the
/// fence is never closed, everything after it is returned; if no fence is
/// present at all, returns `None`.
fn extract_codeblock(response: &str) -> Option<String> {
    let mut lines = response.lines();
    // Advance past the opening fence; bail out if there is none.
    lines.by_ref().find(|line| line.starts_with("```"))?;

    let mut body: Vec<&str> = Vec::new();
    for line in lines {
        if line.starts_with("```") {
            // Closing fence: stop collecting.
            break;
        }
        body.push(line);
    }
    Some(body.join("\n"))
}
120
121/// Parse the LLM response into a QaResult.
122fn parse_response(response_text: &str) -> QaResult {
123 let codeblock = extract_codeblock(response_text);
124
125 // Try parsing codeblock first, then fall back to raw response
126 for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
127 let Some(text) = text_to_parse else {
128 continue;
129 };
130
131 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
132 return QaResult {
133 reasoning: parsed
134 .get("reasoning")
135 .and_then(|v| v.as_str())
136 .map(|s| s.to_string()),
137 reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
138 confidence: parsed
139 .get("confidence")
140 .and_then(|v| v.as_u64())
141 .map(|v| v as u8),
142 response: Some(response_text.to_string()),
143 error: None,
144 };
145 }
146 }
147
148 // If all parsing attempts fail, return error
149 QaResult {
150 reasoning: Some(response_text.to_string()),
151 reverts_edits: None,
152 confidence: None,
153 response: Some(response_text.to_string()),
154 error: Some("Could not parse JSON from response".to_string()),
155 }
156}
157
/// Thin wrapper that dispatches requests to whichever provider client the
/// user selected via `--backend`.
enum QaClient {
    Anthropic(AnthropicClient),
    OpenAi(OpenAiClient),
}
162
impl QaClient {
    /// Send `prompt` as a single user message to the active provider and
    /// return the response's text content, with all text parts concatenated.
    ///
    /// Returns `Ok(None)` when the client yields no response — presumably
    /// when the request was only queued for a later batch; verify against
    /// the client implementations.
    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
        match self {
            QaClient::Anthropic(client) => {
                let messages = vec![anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: prompt.to_string(),
                        cache_control: None,
                    }],
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Join every text block in the response; non-text content is
                // silently dropped.
                Ok(response.map(|r| {
                    r.content
                        .iter()
                        .filter_map(|c| match c {
                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
            QaClient::OpenAi(client) => {
                let messages = vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt.to_string()),
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Concatenate assistant text across all choices: plain
                // content is used as-is, multipart content keeps only its
                // text parts; non-assistant messages are skipped.
                Ok(response.map(|r| {
                    r.choices
                        .into_iter()
                        .filter_map(|choice| match choice.message {
                            open_ai::RequestMessage::Assistant { content, .. } => {
                                content.map(|c| match c {
                                    open_ai::MessageContent::Plain(text) => text,
                                    open_ai::MessageContent::Multipart(parts) => parts
                                        .into_iter()
                                        .filter_map(|p| match p {
                                            open_ai::MessagePart::Text { text } => Some(text),
                                            _ => None,
                                        })
                                        .collect::<Vec<_>>()
                                        .join(""),
                                })
                            }
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
        }
    }

    /// Upload pending batch requests and download finished batch results for
    /// whichever provider client is active.
    async fn sync_batches(&self) -> Result<()> {
        match self {
            QaClient::Anthropic(client) => client.sync_batches().await,
            QaClient::OpenAi(client) => client.sync_batches().await,
        }
    }
}
228
229/// Run the QA evaluation on a set of examples.
230pub async fn run_qa(
231 examples: &mut [Example],
232 args: &QaArgs,
233 output_path: Option<&PathBuf>,
234) -> Result<()> {
235 let model = model_for_backend(args.backend);
236 let client = match args.backend {
237 BatchProvider::Anthropic => {
238 if args.no_batch {
239 QaClient::Anthropic(AnthropicClient::plain()?)
240 } else {
241 QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
242 }
243 }
244 BatchProvider::Openai => {
245 if args.no_batch {
246 QaClient::OpenAi(OpenAiClient::plain()?)
247 } else {
248 QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
249 }
250 }
251 };
252
253 eprintln!(
254 "Using model: {}, backend: {:?}, batching: {}",
255 model, args.backend, !args.no_batch
256 );
257
258 // First pass: send requests (client handles caching internally)
259 let mut prompts: Vec<(usize, String)> = Vec::new();
260 let mut skipped_count = 0;
261
262 for (idx, example) in examples.iter().enumerate() {
263 let Some(prompt) = build_prompt(example) else {
264 skipped_count += 1;
265 continue;
266 };
267 prompts.push((idx, prompt));
268 }
269
270 if skipped_count > 0 {
271 eprintln!("Skipping {} items with missing actual_patch", skipped_count);
272 }
273
274 eprintln!("{} items to process", prompts.len());
275
276 // Process all items
277 let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();
278
279 if args.no_batch {
280 // Synchronous processing
281 for (i, (idx, prompt)) in prompts.iter().enumerate() {
282 eprint!("\rProcessing {}/{}", i + 1, prompts.len());
283
284 let response = client.generate(model, 1024, prompt).await?;
285 let result = response.map(|text| parse_response(&text));
286 results.push((*idx, result));
287 }
288 eprintln!();
289 } else {
290 // Queue all for batching
291 for (idx, prompt) in &prompts {
292 let response = client.generate(model, 1024, prompt).await?;
293 let result = response.map(|text| parse_response(&text));
294 results.push((*idx, result));
295 }
296
297 // Sync batches (upload pending, download finished)
298 client.sync_batches().await?;
299
300 if args.wait {
301 eprintln!("Waiting for batch to complete...");
302 loop {
303 std::thread::sleep(std::time::Duration::from_secs(30));
304 client.sync_batches().await?;
305
306 // Re-check all items that didn't have results
307 let mut all_done = true;
308 for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
309 if results[result_idx].1.is_none() {
310 let response = client.generate(model, 1024, prompt).await?;
311 if let Some(text) = response {
312 results[result_idx] = (*idx, Some(parse_response(&text)));
313 } else {
314 all_done = false;
315 }
316 }
317 }
318
319 let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
320 if all_done {
321 break;
322 }
323 eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
324 }
325 } else {
326 let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
327 if pending_count > 0 {
328 eprintln!(
329 "Batch submitted. {} pending. Run again later to retrieve results.",
330 pending_count
331 );
332 }
333 }
334 }
335
336 // Build results map by index
337 let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
338 std::collections::HashMap::new();
339 for (idx, result) in results {
340 if let Some(r) = result {
341 results_by_idx.insert(idx, r);
342 }
343 }
344
345 // Output results
346 let mut writer: Box<dyn Write> = if let Some(path) = output_path {
347 Box::new(BufWriter::new(std::fs::File::create(path)?))
348 } else {
349 Box::new(std::io::stdout())
350 };
351
352 let mut num_total = 0;
353 let mut num_reverts_edits = 0;
354
355 for (idx, example) in examples.iter_mut().enumerate() {
356 // Skip examples that couldn't be processed
357 if build_prompt(example).is_none() {
358 continue;
359 }
360
361 let result = results_by_idx.get(&idx).cloned();
362
363 if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
364 num_reverts_edits += 1;
365 }
366 num_total += 1;
367
368 // Populate QA results for each prediction (currently only first prediction is evaluated)
369 example.qa = example
370 .predictions
371 .iter()
372 .enumerate()
373 .map(|(i, _)| if i == 0 { result.clone() } else { None })
374 .collect();
375
376 writeln!(writer, "{}", serde_json::to_string(&example)?)?;
377 }
378
379 if let Some(path) = output_path {
380 eprintln!("Results written to {}", path.display());
381 }
382
383 eprintln!("Processed: {} items", num_total);
384 if num_total > 0 {
385 eprintln!(
386 "Reverts edits: {} ({:.2}%)",
387 num_reverts_edits,
388 num_reverts_edits as f64 / num_total as f64 * 100.0
389 );
390 }
391
392 Ok(())
393}