1//! Repair predictions that received poor QA scores.
2//!
3//! This module takes examples with predictions and QA feedback, identifies
4//! predictions that need improvement (based on reverts_edits or low confidence),
5//! and uses an LLM to generate improved predictions.
6
7use crate::PredictionProvider;
8use crate::anthropic_client::AnthropicClient;
9use crate::example::{Example, ExamplePrediction};
10use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
11use crate::paths::LLM_CACHE_DB;
12use crate::word_diff::unified_to_word_diff;
13use anthropic::{Message, RequestContent, Role};
14use anyhow::Result;
15use std::io::{BufWriter, Write};
16use std::path::PathBuf;
17
/// Model to use for repair.
const MODEL: &str = "claude-sonnet-4-5";

/// Repair prompt template; `{placeholder}` slots are filled in by `build_repair_prompt`.
const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");
22
/// Arguments for the repair command.
//
// NOTE: the `///` comments on the fields below are rendered by clap as the
// `--help` text for each flag — treat them as user-facing output, not as
// ordinary documentation.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,
}
38
39/// Build the repair prompt for an example that needs improvement.
40///
41/// Returns None if the example doesn't have the required data (predictions, qa, prompt_inputs).
42pub fn build_repair_prompt(example: &Example) -> Option<String> {
43 let prediction = example.predictions.first()?;
44 let qa = example.qa.first()?.as_ref()?;
45 let prompt_inputs = example.prompt_inputs.as_ref()?;
46 let actual_patch = prediction.actual_patch.as_ref()?;
47
48 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
49
50 // Format edit history similar to qa.rs
51 let mut edit_history = String::new();
52 for event in &prompt_inputs.edit_history {
53 match event.as_ref() {
54 zeta_prompt::Event::BufferChange {
55 path,
56 old_path,
57 diff,
58 predicted: _,
59 in_open_source_repo: _,
60 } => {
61 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
62 edit_history.push_str(&format!("+++ b{}\n", path.display()));
63 let diff_word_diff = unified_to_word_diff(diff);
64 edit_history.push_str(&diff_word_diff);
65 edit_history.push_str("\n\n");
66 }
67 }
68 }
69
70 // Format related files context (reuse from TeacherPrompt)
71 let context = TeacherPrompt::format_context(example);
72
73 // Format cursor excerpt with editable region markers (reuse from format_prompt)
74 let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
75
76 // Get QA feedback
77 let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
78 let reverts_edits = qa
79 .reverts_edits
80 .map_or("unknown", |v| if v { "yes" } else { "no" });
81 let confidence = qa
82 .confidence
83 .map_or("unknown".to_string(), |v| v.to_string());
84
85 Some(
86 PROMPT_TEMPLATE
87 .replace("{edit_history}", &edit_history)
88 .replace("{context}", &context)
89 .replace("{cursor_excerpt}", &cursor_excerpt)
90 .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
91 .replace("{reverts_edits}", reverts_edits)
92 .replace("{confidence}", &confidence)
93 .replace("{qa_reasoning}", qa_reasoning),
94 )
95}
96
97/// Check if an example needs repair based on QA feedback.
98pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
99 let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) else {
100 return false;
101 };
102
103 // Repair if reverts_edits is true
104 if qa.reverts_edits == Some(true) {
105 return true;
106 }
107
108 // Repair if confidence is at or below threshold
109 if let Some(confidence) = qa.confidence {
110 if confidence <= confidence_threshold {
111 return true;
112 }
113 }
114
115 false
116}
117
118/// Parse the repair response into a prediction.
119fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
120 let actual_patch = TeacherPrompt::parse(example, response_text)?;
121
122 Ok(ExamplePrediction {
123 actual_patch: Some(actual_patch),
124 actual_output: response_text.to_string(),
125 error: None,
126 provider: PredictionProvider::Repair,
127 })
128}
129
130/// Run the repair process on a set of examples.
131pub async fn run_repair(
132 examples: &mut [Example],
133 args: &RepairArgs,
134 output_path: Option<&PathBuf>,
135) -> Result<()> {
136 let client = if args.no_batch {
137 AnthropicClient::plain()?
138 } else {
139 AnthropicClient::batch(&LLM_CACHE_DB)?
140 };
141
142 eprintln!(
143 "Using model: {}, batching: {}, confidence_threshold: {}",
144 MODEL, !args.no_batch, args.confidence_threshold
145 );
146
147 // First pass: identify examples that need repair and build prompts
148 let mut repair_items: Vec<(usize, String)> = Vec::new();
149 let mut skipped_missing_data = 0;
150 let mut skipped_no_repair_needed = 0;
151
152 for (idx, example) in examples.iter().enumerate() {
153 // Skip if missing predictions or qa
154 if example.predictions.is_empty() || example.qa.is_empty() {
155 skipped_missing_data += 1;
156 continue;
157 }
158
159 // Skip if doesn't need repair
160 if !needs_repair(example, args.confidence_threshold) {
161 skipped_no_repair_needed += 1;
162 continue;
163 }
164
165 // Build repair prompt
166 let Some(prompt) = build_repair_prompt(example) else {
167 skipped_missing_data += 1;
168 continue;
169 };
170
171 repair_items.push((idx, prompt));
172 }
173
174 eprintln!(
175 "Skipping {} items with missing data, {} items that don't need repair",
176 skipped_missing_data, skipped_no_repair_needed
177 );
178 eprintln!("{} items to repair", repair_items.len());
179
180 // Process all items
181 let mut results: Vec<(usize, Option<String>)> = Vec::new();
182
183 if args.no_batch {
184 // Synchronous processing
185 for (i, (idx, prompt)) in repair_items.iter().enumerate() {
186 eprint!("\rProcessing {}/{}", i + 1, repair_items.len());
187
188 let messages = vec![Message {
189 role: Role::User,
190 content: vec![RequestContent::Text {
191 text: prompt.clone(),
192 cache_control: None,
193 }],
194 }];
195
196 let response = client.generate(MODEL, 16384, messages).await?;
197 let result = response.map(|r| {
198 r.content
199 .iter()
200 .filter_map(|c| match c {
201 anthropic::ResponseContent::Text { text } => Some(text.as_str()),
202 _ => None,
203 })
204 .collect::<Vec<_>>()
205 .join("")
206 });
207 results.push((*idx, result));
208 }
209 eprintln!();
210 } else {
211 // Queue all for batching
212 for (idx, prompt) in &repair_items {
213 let messages = vec![Message {
214 role: Role::User,
215 content: vec![RequestContent::Text {
216 text: prompt.clone(),
217 cache_control: None,
218 }],
219 }];
220
221 let response = client.generate(MODEL, 16384, messages).await?;
222 let result = response.map(|r| {
223 r.content
224 .iter()
225 .filter_map(|c| match c {
226 anthropic::ResponseContent::Text { text } => Some(text.as_str()),
227 _ => None,
228 })
229 .collect::<Vec<_>>()
230 .join("")
231 });
232 results.push((*idx, result));
233 }
234
235 // Sync batches (upload pending, download finished)
236 client.sync_batches().await?;
237
238 if args.wait {
239 eprintln!("Waiting for batch to complete...");
240 loop {
241 std::thread::sleep(std::time::Duration::from_secs(30));
242 client.sync_batches().await?;
243
244 // Re-check all items that didn't have results
245 let mut all_done = true;
246 for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
247 if results[result_idx].1.is_none() {
248 let messages = vec![Message {
249 role: Role::User,
250 content: vec![RequestContent::Text {
251 text: prompt.clone(),
252 cache_control: None,
253 }],
254 }];
255
256 let response = client.generate(MODEL, 16384, messages).await?;
257 if let Some(r) = response {
258 let text = r
259 .content
260 .iter()
261 .filter_map(|c| match c {
262 anthropic::ResponseContent::Text { text } => {
263 Some(text.as_str())
264 }
265 _ => None,
266 })
267 .collect::<Vec<_>>()
268 .join("");
269 results[result_idx] = (*idx, Some(text));
270 } else {
271 all_done = false;
272 }
273 }
274 }
275
276 let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
277 if all_done {
278 break;
279 }
280 eprintln!(
281 "Still waiting... {}/{} results",
282 done_count,
283 repair_items.len()
284 );
285 }
286 } else {
287 let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
288 if pending_count > 0 {
289 eprintln!(
290 "Batch submitted. {} pending. Run again later to retrieve results.",
291 pending_count
292 );
293 }
294 }
295 }
296
297 // Build results map by index
298 let mut results_by_idx: std::collections::HashMap<usize, String> =
299 std::collections::HashMap::new();
300 for (idx, result) in results {
301 if let Some(r) = result {
302 results_by_idx.insert(idx, r);
303 }
304 }
305
306 // Output results
307 let mut writer: Box<dyn Write> = if let Some(path) = output_path {
308 Box::new(BufWriter::new(std::fs::File::create(path)?))
309 } else {
310 Box::new(std::io::stdout())
311 };
312
313 let mut num_repaired = 0;
314 let mut num_repair_errors = 0;
315
316 for (idx, example) in examples.iter_mut().enumerate() {
317 // Add repair prediction if we have a result
318 if let Some(response_text) = results_by_idx.get(&idx) {
319 match parse_repair_response(example, response_text) {
320 Ok(prediction) => {
321 example.predictions.push(prediction);
322 num_repaired += 1;
323 }
324 Err(e) => {
325 // Add error prediction
326 example.predictions.push(ExamplePrediction {
327 actual_patch: None,
328 actual_output: response_text.clone(),
329 error: Some(format!("Failed to parse repair response: {}", e)),
330 provider: PredictionProvider::Repair,
331 });
332 num_repair_errors += 1;
333 }
334 }
335 }
336
337 writeln!(writer, "{}", serde_json::to_string(&example)?)?;
338 }
339
340 if let Some(path) = output_path {
341 eprintln!("Results written to {}", path.display());
342 }
343
344 eprintln!("Repaired: {} items", num_repaired);
345 eprintln!("Repair errors: {} items", num_repair_errors);
346
347 Ok(())
348}