qa.rs

  1//! Quality assessment of predictions using LLM-as-a-judge.
  2//!
  3//! This module uses LLM Batch APIs to evaluate prediction quality.
  4//! Caching is handled by the underlying client.
  5
  6use crate::BatchProvider;
  7use crate::anthropic_client::AnthropicClient;
  8use crate::example::Example;
  9use crate::format_prompt::extract_cursor_excerpt_from_example;
 10use crate::openai_client::OpenAiClient;
 11use crate::paths::LLM_CACHE_DB;
 12use crate::word_diff::unified_to_word_diff;
 13use anyhow::Result;
 14use serde::{Deserialize, Serialize};
 15use std::io::{BufWriter, Write};
 16use std::path::PathBuf;
 17
 18const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
 19
 20/// Arguments for the QA command.
/// Arguments for the QA command.
//
// NOTE: the `///` doc comments on fields below are consumed by clap as the
// CLI `--help` text, so they are user-facing strings — keep them stable.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "openai")]
    pub backend: BatchProvider,
}
 35
/// Map a [`BatchProvider`] to the hard-coded model identifier used for
/// QA judging on that backend.
//
// Exhaustive match (no `_` arm) on purpose: adding a new provider variant
// forces a compile error here rather than silently picking a default model.
fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-5",
        BatchProvider::Openai => "gpt-5.2",
    }
}
 42
 43/// Result of QA evaluation for a single prediction.
/// Result of QA evaluation for a single prediction.
//
// All fields are optional so a partially-parseable judge response still
// round-trips through serde; `skip_serializing_if` keeps the JSONL output
// compact by omitting absent fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
 66
 67/// Build the assessment prompt for an example.
 68pub fn build_prompt(example: &Example) -> Option<String> {
 69    let prediction = example.predictions.first()?;
 70    let actual_patch = prediction.actual_patch.as_ref()?;
 71    let prompt_inputs = example.prompt_inputs.as_ref()?;
 72
 73    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
 74
 75    // Format cursor excerpt (reuse from format_prompt)
 76    let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
 77
 78    let mut edit_history = String::new();
 79    for event in &prompt_inputs.edit_history {
 80        match event.as_ref() {
 81            zeta_prompt::Event::BufferChange {
 82                path,
 83                old_path,
 84                diff,
 85                predicted: _,
 86                in_open_source_repo: _,
 87            } => {
 88                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
 89                edit_history.push_str(&format!("+++ b{}\n", path.display()));
 90                let diff_word_diff = unified_to_word_diff(diff);
 91                edit_history.push_str(&diff_word_diff);
 92                edit_history.push_str("\n\n");
 93            }
 94        }
 95    }
 96
 97    Some(
 98        PROMPT_TEMPLATE
 99            .replace("{edit_history}", &edit_history)
100            .replace("{cursor_excerpt}", &cursor_excerpt)
101            .replace("{actual_patch_word_diff}", &actual_patch_word_diff),
102    )
103}
104
105/// Extract a code block from a response.
/// Extract the contents of the first fenced code block in `response`.
///
/// The opening fence may carry a language tag (e.g. ```` ```json ````).
/// If no closing fence exists, everything after the opening fence is
/// returned; if there is no fence at all, returns `None`.
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();
    // Find the opening fence (column 0 only, matching LLM output style).
    let open = lines.iter().position(|line| line.starts_with("```"))?;
    let body_start = open + 1;
    // Close at the next fence, or run to the end when none is found.
    let body_end = lines[body_start..]
        .iter()
        .position(|line| line.starts_with("```"))
        .map_or(lines.len(), |offset| body_start + offset);
    Some(lines[body_start..body_end].join("\n"))
}
121
122/// Parse the LLM response into a QaResult.
123fn parse_response(response_text: &str) -> QaResult {
124    let codeblock = extract_codeblock(response_text);
125
126    // Try parsing codeblock first, then fall back to raw response
127    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
128        let Some(text) = text_to_parse else {
129            continue;
130        };
131
132        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
133            return QaResult {
134                reasoning: parsed
135                    .get("reasoning")
136                    .and_then(|v| v.as_str())
137                    .map(|s| s.to_string()),
138                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
139                confidence: parsed
140                    .get("confidence")
141                    .and_then(|v| v.as_u64())
142                    .map(|v| v as u8),
143                response: Some(response_text.to_string()),
144                error: None,
145            };
146        }
147    }
148
149    // If all parsing attempts fail, return error
150    QaResult {
151        reasoning: Some(response_text.to_string()),
152        reverts_edits: None,
153        confidence: None,
154        response: Some(response_text.to_string()),
155        error: Some("Could not parse JSON from response".to_string()),
156    }
157}
158
/// Provider-specific client wrapper so the QA run can treat both backends
/// through one `generate`/`sync_batches` surface.
enum QaClient {
    Anthropic(AnthropicClient),
    OpenAi(OpenAiClient),
}
163
impl QaClient {
    /// Send a single user prompt to the selected backend and return the
    /// model's text output.
    ///
    /// `Ok(None)` is passed through from the underlying client — presumably
    /// meaning "no result yet" for a batched request; confirm against the
    /// client implementations. Multiple text parts in a response are
    /// concatenated with no separator.
    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
        match self {
            QaClient::Anthropic(client) => {
                // Single-turn conversation: one user message, no caching.
                let messages = vec![anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: prompt.to_string(),
                        cache_control: None,
                    }],
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Keep only text content blocks; ignore tool use etc.
                Ok(response.map(|r| {
                    r.content
                        .iter()
                        .filter_map(|c| match c {
                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
            QaClient::OpenAi(client) => {
                let messages = vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt.to_string()),
                }];
                let response = client
                    .generate(model, max_tokens, messages, None, false)
                    .await?;
                // Flatten all assistant choices into one string, extracting
                // plain text from both plain and multipart message contents.
                Ok(response.map(|r| {
                    r.choices
                        .into_iter()
                        .filter_map(|choice| match choice.message {
                            open_ai::RequestMessage::Assistant { content, .. } => {
                                content.map(|c| match c {
                                    open_ai::MessageContent::Plain(text) => text,
                                    open_ai::MessageContent::Multipart(parts) => parts
                                        .into_iter()
                                        .filter_map(|p| match p {
                                            open_ai::MessagePart::Text { text } => Some(text),
                                            _ => None,
                                        })
                                        .collect::<Vec<_>>()
                                        .join(""),
                                })
                            }
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
        }
    }

    /// Delegate batch synchronization (upload pending / download finished,
    /// per the call site's comment in `run_qa`) to the underlying client.
    async fn sync_batches(&self) -> Result<()> {
        match self {
            QaClient::Anthropic(client) => client.sync_batches().await,
            QaClient::OpenAi(client) => client.sync_batches().await,
        }
    }
}
229
230/// Run the QA evaluation on a set of examples.
/// Run the QA evaluation on a set of examples.
///
/// Builds one judge prompt per example, sends them either synchronously
/// (`--no-batch`) or through the provider's batch API, parses the judge's
/// JSON verdicts into [`QaResult`]s, attaches them to each example's first
/// prediction, and writes the examples as JSONL to `output_path` (or stdout).
///
/// In batch mode without `--wait`, requests are queued and the function
/// returns with pending items; a later run retrieves cached results
/// (caching is handled by the underlying client, per the module docs).
pub async fn run_qa(
    examples: &mut [Example],
    args: &QaArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let model = model_for_backend(args.backend);
    // Choose plain (synchronous) or batch client; the batch client persists
    // state in LLM_CACHE_DB.
    let client = match args.backend {
        BatchProvider::Anthropic => {
            if args.no_batch {
                QaClient::Anthropic(AnthropicClient::plain()?)
            } else {
                QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
            }
        }
        BatchProvider::Openai => {
            if args.no_batch {
                QaClient::OpenAi(OpenAiClient::plain()?)
            } else {
                QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
            }
        }
    };

    eprintln!(
        "Using model: {}, backend: {:?}, batching: {}",
        model, args.backend, !args.no_batch
    );

    // First pass: send requests (client handles caching internally)
    // Keep (example index, prompt) pairs; examples that can't produce a
    // prompt (missing prediction/patch/inputs) are counted and skipped.
    let mut prompts: Vec<(usize, String)> = Vec::new();
    let mut skipped_count = 0;

    for (idx, example) in examples.iter().enumerate() {
        let Some(prompt) = build_prompt(example) else {
            skipped_count += 1;
            continue;
        };
        prompts.push((idx, prompt));
    }

    if skipped_count > 0 {
        eprintln!("Skipping {} items with missing actual_patch", skipped_count);
    }

    eprintln!("{} items to process", prompts.len());

    // Process all items. `None` results mean "not yet available" (batch
    // pending); they are re-polled below when --wait is set.
    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in prompts.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, prompts.len());

            let response = client.generate(model, 1024, prompt).await?;
            let result = response.map(|text| parse_response(&text));
            results.push((*idx, result));
        }
        eprintln!();
    } else {
        // Queue all for batching. generate() returns immediately with any
        // cached result, or None for items newly queued.
        for (idx, prompt) in &prompts {
            let response = client.generate(model, 1024, prompt).await?;
            let result = response.map(|text| parse_response(&text));
            results.push((*idx, result));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
                // NOTE(review): std::thread::sleep blocks the async executor
                // thread for 30s per poll — tokio::time::sleep (or the
                // runtime's equivalent) would be the non-blocking form;
                // confirm which runtime drives this future.
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        let response = client.generate(model, 1024, prompt).await?;
                        if let Some(text) = response {
                            results[result_idx] = (*idx, Some(parse_response(&text)));
                        } else {
                            all_done = false;
                        }
                    }
                }

                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                if all_done {
                    break;
                }
                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
            }
        } else {
            // Fire-and-forget: report how many items are still pending so the
            // user knows to re-run later.
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_total = 0;
    let mut num_reverts_edits = 0;

    for (idx, example) in examples.iter_mut().enumerate() {
        // Skip examples that couldn't be processed
        // (build_prompt is re-evaluated here rather than tracking the
        // skipped set from the first pass; skipped examples are also
        // omitted from the output entirely)
        if build_prompt(example).is_none() {
            continue;
        }

        let result = results_by_idx.get(&idx).cloned();

        if result.as_ref().and_then(|r| r.reverts_edits) == Some(true) {
            num_reverts_edits += 1;
        }
        num_total += 1;

        // Populate QA results for each prediction (currently only first prediction is evaluated)
        example.qa = example
            .predictions
            .iter()
            .enumerate()
            .map(|(i, _)| if i == 0 { result.clone() } else { None })
            .collect();

        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
    }

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    // Summary statistics to stderr (num_total counts processed items, even
    // those whose batch result is still pending).
    eprintln!("Processed:     {} items", num_total);
    if num_total > 0 {
        eprintln!(
            "Reverts edits: {} ({:.2}%)",
            num_reverts_edits,
            num_reverts_edits as f64 / num_total as f64 * 100.0
        );
    }

    Ok(())
}