qa.rs

  1//! Quality assessment of predictions using LLM-as-a-judge.
  2//!
  3//! This module uses LLM Batch APIs to evaluate prediction quality.
  4//! Caching is handled by the underlying client.
  5
  6use crate::{
  7    BatchProvider,
  8    anthropic_client::AnthropicClient,
  9    example::Example,
 10    format_prompt::extract_cursor_excerpt_from_example,
 11    openai_client::OpenAiClient,
 12    parse_output::run_parse_output,
 13    paths::LLM_CACHE_DB,
 14    progress::{ExampleProgress, Step},
 15    word_diff::unified_to_word_diff,
 16};
 17use anyhow::{Context as _, Result};
 18use serde::{Deserialize, Serialize};
 19use std::sync::OnceLock;
 20
/// Arguments for the QA command.
///
/// Controls how judge requests are sent: which provider to call and
/// whether to use the synchronous API instead of the batch API.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Which LLM provider to use (anthropic or openai); defaults to openai.
    #[clap(long, default_value = "openai")]
    pub backend: BatchProvider,
}
 32
 33fn model_for_backend(backend: BatchProvider) -> &'static str {
 34    match backend {
 35        BatchProvider::Anthropic => "claude-sonnet-4-5",
 36        BatchProvider::Openai => "gpt-5.2",
 37    }
 38}
 39
/// Result of QA evaluation for a single prediction.
///
/// Every field is optional so that partially-parsed or failed judge
/// responses can still be stored; absent fields are omitted when
/// serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
 63
 64/// Build the assessment prompt for an example.
 65pub fn build_prompt(example: &Example) -> Result<String> {
 66    let prediction = example
 67        .predictions
 68        .first()
 69        .context("no predictions available")?;
 70    let actual_patch = prediction
 71        .actual_patch
 72        .as_ref()
 73        .context("no actual_patch available (run predict first)")?;
 74    let prompt_inputs = example
 75        .prompt_inputs
 76        .as_ref()
 77        .context("prompt_inputs missing (run context retrieval first)")?;
 78
 79    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
 80
 81    let cursor_excerpt =
 82        extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?;
 83
 84    let mut edit_history = String::new();
 85    for event in &prompt_inputs.edit_history {
 86        match event.as_ref() {
 87            zeta_prompt::Event::BufferChange {
 88                path,
 89                old_path,
 90                diff,
 91                predicted: _,
 92                in_open_source_repo: _,
 93            } => {
 94                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
 95                edit_history.push_str(&format!("+++ b{}\n", path.display()));
 96                let diff_word_diff = unified_to_word_diff(diff);
 97                edit_history.push_str(&diff_word_diff);
 98                edit_history.push_str("\n\n");
 99            }
100        }
101    }
102
103    let prompt_template = crate::prompt_assets::get_prompt("qa.md");
104    Ok(prompt_template
105        .replace("{edit_history}", &edit_history)
106        .replace("{cursor_excerpt}", &cursor_excerpt)
107        .replace("{actual_patch_word_diff}", &actual_patch_word_diff))
108}
109
/// Extract the contents of the first fenced (```-delimited) code block in
/// `response`, excluding the fence lines themselves.
///
/// If the opening fence is never closed, everything after it is returned;
/// if there is no fence at all, returns `None`.
fn extract_codeblock(response: &str) -> Option<String> {
    let lines: Vec<&str> = response.lines().collect();
    // First line starting with ``` opens the block; content begins after it.
    let body_start = lines.iter().position(|line| line.starts_with("```"))? + 1;
    // Next ``` line closes the block; otherwise take everything remaining.
    let body_len = lines[body_start..]
        .iter()
        .position(|line| line.starts_with("```"))
        .unwrap_or(lines.len() - body_start);
    Some(lines[body_start..body_start + body_len].join("\n"))
}
125
126fn parse_response(response_text: &str) -> QaResult {
127    let codeblock = extract_codeblock(response_text);
128
129    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
130        let Some(text) = text_to_parse else {
131            continue;
132        };
133
134        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
135            return QaResult {
136                reasoning: parsed
137                    .get("reasoning")
138                    .and_then(|v| v.as_str())
139                    .map(|s| s.to_string()),
140                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
141                confidence: parsed
142                    .get("confidence")
143                    .and_then(|v| v.as_u64())
144                    .map(|v| v as u8),
145                response: Some(response_text.to_string()),
146                error: None,
147            };
148        }
149    }
150
151    QaResult {
152        reasoning: Some(response_text.to_string()),
153        reverts_edits: None,
154        confidence: None,
155        response: Some(response_text.to_string()),
156        error: Some("Could not parse JSON from response".to_string()),
157    }
158}
159
// Lazily-initialized, process-wide LLM clients — one per (provider, mode)
// pair so batch and plain (synchronous) modes never share a client.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
164
/// Run QA evaluation for a single example.
///
/// Judges the first prediction's `actual_patch` with the configured LLM and
/// stores the parsed verdict in `example.qa`, which is rebuilt to be the
/// same length as `example.predictions` with only index 0 populated.
/// No-op if a confidence score is already present for the first prediction.
pub async fn run_qa(
    example: &mut Example,
    args: &QaArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Already judged: a confidence score exists for the first prediction.
    if example
        .qa
        .first()
        .and_then(|q| q.as_ref())
        .and_then(|q| q.confidence)
        .is_some()
    {
        return Ok(());
    }

    // Ensure `actual_patch` has been derived from the raw model output.
    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // `build_prompt` would fail anyway; bail early with a clear message.
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    let step_progress = example_progress.start(Step::Qa);

    let model = model_for_backend(args.backend);
    let prompt = build_prompt(example).context("Failed to build QA prompt")?;

    step_progress.set_substatus("generating");

    let response = match args.backend {
        BatchProvider::Anthropic => {
            // Batch client persists/caches via LLM_CACHE_DB; plain client
            // calls the API synchronously.
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![anthropic::Message {
                role: anthropic::Role::User,
                content: vec![anthropic::RequestContent::Text {
                    text: prompt,
                    cache_control: None,
                }],
            }];

            // NOTE(review): `None` presumably means the request was queued in
            // a batch and isn't finished yet, so QA stays unset until a later
            // `sync_batches` run — confirm against the client implementation.
            let Some(response) = client.generate(model, 1024, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text content blocks; non-text blocks are dropped.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }];

            // Same batch-pending convention as the Anthropic arm above.
            let Some(response) = client.generate(model, 1024, messages, None, false).await? else {
                return Ok(());
            };

            // Flatten assistant text (plain or multipart) across all choices;
            // non-assistant messages and non-text parts are dropped.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    let result = parse_response(&response);

    // Size `qa` to match `predictions`: verdict at index 0, `None` elsewhere.
    example.qa = example
        .predictions
        .iter()
        .enumerate()
        .map(|(i, _)| if i == 0 { Some(result.clone()) } else { None })
        .collect();

    Ok(())
}
282
283/// Sync batches for QA (upload pending requests, download finished results).
284pub async fn sync_batches(args: &QaArgs) -> Result<()> {
285    if args.no_batch {
286        return Ok(());
287    }
288
289    match args.backend {
290        BatchProvider::Anthropic => {
291            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
292                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
293            });
294            client.sync_batches().await?;
295        }
296        BatchProvider::Openai => {
297            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
298                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
299            });
300            client.sync_batches().await?;
301        }
302    }
303    Ok(())
304}