//! Repair predictions that received poor QA scores.
//!
//! This module takes examples with predictions and QA feedback, identifies
//! predictions that need improvement (based on reverts_edits or low confidence),
//! and uses an LLM to generate improved predictions.
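//!
//! A rough usage sketch (the `load_examples` helper and the file names are
//! hypothetical; only the `run_repair` call shape is taken from this module):
//!
//! ```ignore
//! let mut examples: Vec<Example> = load_examples("qa_scored.jsonl")?;
//! let args = RepairArgs {
//!     no_batch: false,
//!     wait: true,
//!     confidence_threshold: 2,
//!     backend: BatchProvider::Anthropic,
//! };
//! run_repair(&mut examples, &args, Some(&PathBuf::from("repaired.jsonl"))).await?;
//! ```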

use crate::BatchProvider;
use crate::PredictionProvider;
use crate::anthropic_client::AnthropicClient;
use crate::example::{Example, ExamplePrediction};
use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
use crate::openai_client::OpenAiClient;
use crate::paths::LLM_CACHE_DB;
use crate::word_diff::unified_to_word_diff;
use anyhow::Result;
use std::io::{BufWriter, Write};
use std::path::PathBuf;

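// `prompts/repair.md` is expected to contain the placeholders substituted in
// `build_repair_prompt` below: {edit_history}, {context}, {cursor_excerpt},
// {actual_patch_word_diff}, {reverts_edits}, {confidence}, and {qa_reasoning}.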
const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");

/// Arguments for the repair command.
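///
/// A hypothetical invocation (the `repair` subcommand name is assumed here; the
/// flag names follow from the `clap` attributes below):
///
/// ```text
/// repair --backend openai --confidence-threshold 3 --wait
/// ```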
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,
}

fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-5",
        BatchProvider::Openai => "gpt-5.2",
    }
}

/// Build the repair prompt for an example that needs improvement.
///
/// Returns `None` if the example is missing the required data (a prediction with
/// an actual patch, QA feedback, prompt inputs, or a cursor excerpt).
pub fn build_repair_prompt(example: &Example) -> Option<String> {
    let prediction = example.predictions.first()?;
    let qa = example.qa.first()?.as_ref()?;
    let prompt_inputs = example.prompt_inputs.as_ref()?;
    let actual_patch = prediction.actual_patch.as_ref()?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    // Format edit history similar to qa.rs
    let mut edit_history = String::new();
    for event in &prompt_inputs.edit_history {
        match event.as_ref() {
            zeta_prompt::Event::BufferChange {
                path,
                old_path,
                diff,
                predicted: _,
                in_open_source_repo: _,
            } => {
                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
                edit_history.push_str(&format!("+++ b{}\n", path.display()));
                let diff_word_diff = unified_to_word_diff(diff);
                edit_history.push_str(&diff_word_diff);
                edit_history.push_str("\n\n");
            }
        }
    }

    // Format related files context (reuse from TeacherPrompt)
    let context = TeacherPrompt::format_context(example);

    // Format cursor excerpt with editable region markers (reuse from format_prompt)
    let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;

    // Get QA feedback
    let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
    let reverts_edits = qa
        .reverts_edits
        .map_or("unknown", |v| if v { "yes" } else { "no" });
    let confidence = qa
        .confidence
        .map_or("unknown".to_string(), |v| v.to_string());

    Some(
        PROMPT_TEMPLATE
            .replace("{edit_history}", &edit_history)
            .replace("{context}", &context)
            .replace("{cursor_excerpt}", &cursor_excerpt)
            .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
            .replace("{reverts_edits}", reverts_edits)
            .replace("{confidence}", &confidence)
            .replace("{qa_reasoning}", qa_reasoning),
    )
}

/// Check if an example needs repair based on QA feedback.
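///
/// Concretely, with the default threshold of 2: a `reverts_edits == true` verdict
/// always triggers a repair, and otherwise a QA confidence of 1 or 2 does, while
/// a confidence of 3-5 leaves the prediction as-is.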
pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
    let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) else {
        return false;
    };

    // Repair if reverts_edits is true
    if qa.reverts_edits == Some(true) {
        return true;
    }

    // Repair if confidence is at or below threshold
    if let Some(confidence) = qa.confidence {
        if confidence <= confidence_threshold {
            return true;
        }
    }

    false
}

/// Parse the repair response into a prediction.
fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
    let actual_patch = TeacherPrompt::parse(example, response_text)?;

    Ok(ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output: response_text.to_string(),
        error: None,
        provider: PredictionProvider::Repair,
    })
}

enum RepairClient {
    Anthropic(AnthropicClient),
    OpenAi(OpenAiClient),
}

impl RepairClient {
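    /// Send a single repair prompt and return the response text, if available.
    ///
    /// With the batch-backed clients this is expected to return `Ok(None)` until
    /// the batched result has been retrieved; the polling loop in `run_repair`
    /// relies on re-calling it after `sync_batches`.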
    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
        match self {
            RepairClient::Anthropic(client) => {
                let messages = vec![anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: prompt.to_string(),
                        cache_control: None,
                    }],
                }];
                let response = client.generate(model, max_tokens, messages, None).await?;
                Ok(response.map(|r| {
                    r.content
                        .iter()
                        .filter_map(|c| match c {
                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
            RepairClient::OpenAi(client) => {
                let messages = vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt.to_string()),
                }];
                let response = client.generate(model, max_tokens, messages, None).await?;
                Ok(response.map(|r| {
                    r.choices
                        .into_iter()
                        .filter_map(|choice| match choice.message {
                            open_ai::RequestMessage::Assistant { content, .. } => {
                                content.map(|c| match c {
                                    open_ai::MessageContent::Plain(text) => text,
                                    open_ai::MessageContent::Multipart(parts) => parts
                                        .into_iter()
                                        .filter_map(|p| match p {
                                            open_ai::MessagePart::Text { text } => Some(text),
                                            _ => None,
                                        })
                                        .collect::<Vec<_>>()
                                        .join(""),
                                })
                            }
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
        }
    }

    async fn sync_batches(&self) -> Result<()> {
        match self {
            RepairClient::Anthropic(client) => client.sync_batches().await,
            RepairClient::OpenAi(client) => client.sync_batches().await,
        }
    }
}

/// Run the repair process on a set of examples.
pub async fn run_repair(
    examples: &mut [Example],
    args: &RepairArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let model = model_for_backend(args.backend);
    let client = match args.backend {
        BatchProvider::Anthropic => {
            if args.no_batch {
                RepairClient::Anthropic(AnthropicClient::plain()?)
            } else {
                RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
            }
        }
        BatchProvider::Openai => {
            if args.no_batch {
                RepairClient::OpenAi(OpenAiClient::plain()?)
            } else {
                RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
            }
        }
    };

    eprintln!(
        "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}",
        model, args.backend, !args.no_batch, args.confidence_threshold
    );

    // First pass: identify examples that need repair and build prompts
    let mut repair_items: Vec<(usize, String)> = Vec::new();
    let mut skipped_missing_data = 0;
    let mut skipped_no_repair_needed = 0;

    for (idx, example) in examples.iter().enumerate() {
        // Skip if missing predictions or qa
        if example.predictions.is_empty() || example.qa.is_empty() {
            skipped_missing_data += 1;
            continue;
        }

        // Skip if doesn't need repair
        if !needs_repair(example, args.confidence_threshold) {
            skipped_no_repair_needed += 1;
            continue;
        }

        // Build repair prompt
        let Some(prompt) = build_repair_prompt(example) else {
            skipped_missing_data += 1;
            continue;
        };

        repair_items.push((idx, prompt));
    }

    eprintln!(
        "Skipping {} items with missing data, {} items that don't need repair",
        skipped_missing_data, skipped_no_repair_needed
    );
    eprintln!("{} items to repair", repair_items.len());

    // Process all items
    let mut results: Vec<(usize, Option<String>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in repair_items.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, repair_items.len());

            let response = client.generate(model, 16384, prompt).await?;
            results.push((*idx, response));
        }
        eprintln!();
    } else {
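        // With the batch-backed clients, `generate` queues the request and returns
        // `None` until the batch result is available; `sync_batches` uploads pending
        // requests and downloads finished ones, so re-calling `generate` with the
        // same prompt later (as the `--wait` loop below does) picks up the result.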
        // Queue all for batching
        for (idx, prompt) in &repair_items {
            let response = client.generate(model, 16384, prompt).await?;
            results.push((*idx, response));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        let response = client.generate(model, 16384, prompt).await?;
                        if let Some(text) = response {
                            results[result_idx] = (*idx, Some(text));
                        } else {
                            all_done = false;
                        }
                    }
                }

                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                if all_done {
                    break;
                }
                eprintln!(
                    "Still waiting... {}/{} results",
                    done_count,
                    repair_items.len()
                );
            }
        } else {
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, String> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_repaired = 0;
    let mut num_repair_errors = 0;

    for (idx, example) in examples.iter_mut().enumerate() {
        // Add repair prediction if we have a result
        if let Some(response_text) = results_by_idx.get(&idx) {
            match parse_repair_response(example, response_text) {
                Ok(prediction) => {
                    example.predictions.push(prediction);
                    num_repaired += 1;
                }
                Err(e) => {
                    // Add error prediction
                    example.predictions.push(ExamplePrediction {
                        actual_patch: None,
                        actual_output: response_text.clone(),
                        error: Some(format!("Failed to parse repair response: {}", e)),
                        provider: PredictionProvider::Repair,
                    });
                    num_repair_errors += 1;
                }
            }
        }

        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
    }

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Repaired: {} items", num_repaired);
    eprintln!("Repair errors: {} items", num_repair_errors);

    Ok(())
}