qa.rs

  1//! Quality assessment of predictions using LLM-as-a-judge.
  2//!
  3//! This module uses LLM Batch APIs to evaluate prediction quality.
  4//! Caching is handled by the underlying client.
  5
  6use crate::{
  7    BatchProvider,
  8    anthropic_client::AnthropicClient,
  9    example::Example,
 10    format_prompt::extract_cursor_excerpt_from_example,
 11    openai_client::OpenAiClient,
 12    parse_output::run_parse_output,
 13    paths::LLM_CACHE_DB,
 14    progress::{ExampleProgress, Step},
 15    word_diff::unified_to_word_diff,
 16};
 17use anyhow::{Context as _, Result};
 18use serde::{Deserialize, Serialize};
 19use std::sync::OnceLock;
 20
/// Arguments for the QA command.
///
/// Controls which LLM provider judges the predictions and whether requests go
/// through the provider's batch API (default) or the synchronous API.
#[derive(Debug, Clone, clap::Args)]
pub struct QaArgs {
    /// Use synchronous API instead of batch
    // When unset, requests are enqueued into a batch and resolved by a later
    // `sync_batches` call.
    #[clap(long)]
    pub no_batch: bool,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "openai")]
    pub backend: BatchProvider,
}
 32
 33fn model_for_backend(backend: BatchProvider) -> &'static str {
 34    match backend {
 35        BatchProvider::Anthropic => "claude-sonnet-4-5",
 36        BatchProvider::Openai => "gpt-5.2",
 37    }
 38}
 39
/// Result of QA evaluation for a single prediction.
///
/// All fields are optional because parsing the judge's response is
/// best-effort: on a parse failure only `response`/`error` (and `reasoning`,
/// set to the raw text) are populated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QaResult {
    /// Free-form reasoning from the judge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,

    /// Does the prediction undo/revert changes the user intentionally made?
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reverts_edits: Option<bool>,

    /// Confidence score (1-5) for user acceptance likelihood.
    // Also used as the "already evaluated" marker: `run_qa` skips examples
    // whose first QA result has `confidence` set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<u8>,

    /// The raw response from the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,

    /// Error message if parsing or request failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
 63
 64/// Build the assessment prompt for an example.
 65pub fn build_prompt(example: &Example) -> Result<String> {
 66    let prediction = example
 67        .predictions
 68        .first()
 69        .context("no predictions available")?;
 70    let actual_patch = prediction
 71        .actual_patch
 72        .as_ref()
 73        .context("no actual_patch available (run predict first)")?;
 74    let prompt_inputs = example
 75        .prompt_inputs
 76        .as_ref()
 77        .context("prompt_inputs missing (run context retrieval first)")?;
 78
 79    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
 80
 81    let cursor_excerpt =
 82        extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?;
 83
 84    let mut edit_history = String::new();
 85    for event in &prompt_inputs.events {
 86        let zeta_prompt::Event::BufferChange {
 87            path,
 88            old_path,
 89            diff,
 90            predicted: _,
 91            in_open_source_repo: _,
 92        } = event.as_ref();
 93        edit_history.push_str(&format!("--- a{}\n", old_path.display()));
 94        edit_history.push_str(&format!("+++ b{}\n", path.display()));
 95        let diff_word_diff = unified_to_word_diff(&diff);
 96        edit_history.push_str(&diff_word_diff);
 97        edit_history.push_str("\n\n");
 98    }
 99
100    let prompt_template = crate::prompt_assets::get_prompt("qa.md");
101    Ok(prompt_template
102        .replace("{edit_history}", &edit_history)
103        .replace("{cursor_excerpt}", &cursor_excerpt)
104        .replace("{actual_patch_word_diff}", &actual_patch_word_diff))
105}
106
/// Extract the contents of the first ``` fenced code block in `response`.
///
/// The opening fence must start at the beginning of a line (language tags
/// like ```json are allowed). If no closing fence follows, everything after
/// the opening fence is returned. Returns `None` when no fence exists.
fn extract_codeblock(response: &str) -> Option<String> {
    let all_lines: Vec<&str> = response.lines().collect();
    let open = all_lines.iter().position(|line| line.starts_with("```"))?;
    let after_fence = &all_lines[open + 1..];
    let close = after_fence
        .iter()
        .position(|line| line.starts_with("```"))
        .unwrap_or(after_fence.len());
    Some(after_fence[..close].join("\n"))
}
122
123fn parse_response(response_text: &str) -> QaResult {
124    let codeblock = extract_codeblock(response_text);
125
126    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
127        let Some(text) = text_to_parse else {
128            continue;
129        };
130
131        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
132            return QaResult {
133                reasoning: parsed
134                    .get("reasoning")
135                    .and_then(|v| v.as_str())
136                    .map(|s| s.to_string()),
137                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
138                confidence: parsed
139                    .get("confidence")
140                    .and_then(|v| v.as_u64())
141                    .map(|v| v as u8),
142                response: Some(response_text.to_string()),
143                error: None,
144            };
145        }
146    }
147
148    QaResult {
149        reasoning: Some(response_text.to_string()),
150        reverts_edits: None,
151        confidence: None,
152        response: Some(response_text.to_string()),
153        error: Some("Could not parse JSON from response".to_string()),
154    }
155}
156
// One lazily-initialized client per (provider, mode) combination so every
// example processed in a run shares a single client instance. The batch
// clients are constructed against LLM_CACHE_DB; per the module docs, caching
// is handled by the underlying client.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
161
/// Run QA evaluation for a single example.
///
/// Sends the assessment prompt to the configured provider and stores the
/// parsed [`QaResult`] on the example (aligned with `predictions`, only
/// index 0 is evaluated). In batch mode the request may only be enqueued;
/// in that case the function returns without writing a result — presumably
/// the example is re-run after `sync_batches` completes the batch (TODO
/// confirm against the caller).
pub async fn run_qa(
    example: &mut Example,
    args: &QaArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Skip examples already judged. A stored result without `confidence`
    // (e.g. a JSON parse failure) does not count and is retried.
    if example
        .qa
        .first()
        .and_then(|q| q.as_ref())
        .and_then(|q| q.confidence)
        .is_some()
    {
        return Ok(());
    }

    // Ensure prediction output is parsed (build_prompt needs actual_patch).
    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // Fail early with a remediation hint before starting the progress step.
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    let step_progress = example_progress.start(Step::Qa);

    let model = model_for_backend(args.backend);
    let prompt = build_prompt(example).context("Failed to build QA prompt")?;

    step_progress.set_substatus("generating");

    let response = match args.backend {
        BatchProvider::Anthropic => {
            // Shared, lazily-created client; plain vs batch chosen by flag.
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![anthropic::Message {
                role: anthropic::Role::User,
                content: vec![anthropic::RequestContent::Text {
                    text: prompt,
                    cache_control: None,
                }],
            }];

            // `None` means no response is available yet (batch pending);
            // leave `example.qa` untouched and return.
            let Some(response) = client.generate(model, 1024, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text content parts into one string.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }];

            // As above: `None` means the batch result isn't ready yet.
            let Some(response) = client.generate(model, 1024, messages, None, false).await? else {
                return Ok(());
            };

            // The OpenAI response choice reuses the request-message type;
            // flatten assistant text (plain or multipart) into one string.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    let result = parse_response(&response);

    // Store the result in a Vec parallel to `predictions`: slot 0 gets the
    // QA result, all other slots are None (only the first prediction is judged).
    example.qa = example
        .predictions
        .iter()
        .enumerate()
        .map(|(i, _)| if i == 0 { Some(result.clone()) } else { None })
        .collect();

    Ok(())
}
279
280/// Sync batches for QA (upload pending requests, download finished results).
281pub async fn sync_batches(args: &QaArgs) -> Result<()> {
282    if args.no_batch {
283        return Ok(());
284    }
285
286    match args.backend {
287        BatchProvider::Anthropic => {
288            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
289                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
290            });
291            client.sync_batches().await?;
292        }
293        BatchProvider::Openai => {
294            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
295                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
296            });
297            client.sync_batches().await?;
298        }
299    }
300    Ok(())
301}