1//! Repair predictions that received poor QA scores.
2//!
3//! This module takes examples with predictions and QA feedback, identifies
4//! predictions that need improvement (based on reverts_edits or low confidence),
5//! and uses an LLM to generate improved predictions.
6
7use crate::PredictionProvider;
8use crate::anthropic_client::AnthropicClient;
9use crate::example::{Example, ExamplePrediction};
10use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
11use crate::paths::LLM_CACHE_DB;
12use crate::word_diff::unified_to_word_diff;
13use anthropic::{Message, RequestContent, Role};
14use anyhow::Result;
15use std::io::{BufWriter, Write};
16use std::path::PathBuf;
17
/// Model to use for repair.
const MODEL: &str = "claude-sonnet-4-5";

/// Repair prompt template; `{placeholder}` slots are filled in by `build_repair_prompt`.
const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");
22
/// Arguments for the repair command.
//
// NOTE: the `///` comments on the fields below are rendered by clap as the
// `--help` text for each flag — treat them as user-facing output, not as
// ordinary documentation.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,
}
38
39/// Build the repair prompt for an example that needs improvement.
40///
41/// Returns None if the example doesn't have the required data (predictions, qa, prompt_inputs).
42pub fn build_repair_prompt(example: &Example) -> Option<String> {
43 let prediction = example.predictions.first()?;
44 let qa = example.qa.first()?.as_ref()?;
45 let prompt_inputs = example.prompt_inputs.as_ref()?;
46 let actual_patch = prediction.actual_patch.as_ref()?;
47
48 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
49
50 // Format edit history similar to qa.rs
51 let mut edit_history = String::new();
52 for event in &prompt_inputs.edit_history {
53 match event.as_ref() {
54 zeta_prompt::Event::BufferChange {
55 path,
56 old_path,
57 diff,
58 predicted: _,
59 in_open_source_repo: _,
60 } => {
61 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
62 edit_history.push_str(&format!("+++ b{}\n", path.display()));
63 let diff_word_diff = unified_to_word_diff(diff);
64 edit_history.push_str(&diff_word_diff);
65 edit_history.push_str("\n\n");
66 }
67 }
68 }
69
70 // Format related files context (reuse from TeacherPrompt)
71 let context = TeacherPrompt::format_context(example);
72
73 // Format cursor excerpt with editable region markers (reuse from format_prompt)
74 let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;
75
76 // Get QA feedback
77 let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
78 let reverts_edits = qa
79 .reverts_edits
80 .map_or("unknown", |v| if v { "yes" } else { "no" });
81 let confidence = qa
82 .confidence
83 .map_or("unknown".to_string(), |v| v.to_string());
84
85 Some(
86 PROMPT_TEMPLATE
87 .replace("{edit_history}", &edit_history)
88 .replace("{context}", &context)
89 .replace("{cursor_excerpt}", &cursor_excerpt)
90 .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
91 .replace("{reverts_edits}", reverts_edits)
92 .replace("{confidence}", &confidence)
93 .replace("{qa_reasoning}", qa_reasoning),
94 )
95}
96
97/// Check if an example needs repair based on QA feedback.
98pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
99 let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) else {
100 return false;
101 };
102
103 // Repair if reverts_edits is true
104 if qa.reverts_edits == Some(true) {
105 return true;
106 }
107
108 // Repair if confidence is at or below threshold
109 if let Some(confidence) = qa.confidence {
110 if confidence <= confidence_threshold {
111 return true;
112 }
113 }
114
115 false
116}
117
118/// Parse the repair response into a prediction.
119fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
120 let actual_patch = TeacherPrompt::parse(example, response_text)?;
121
122 Ok(ExamplePrediction {
123 actual_patch: Some(actual_patch),
124 actual_output: response_text.to_string(),
125 error: None,
126 provider: PredictionProvider::Repair,
127 })
128}
129
130/// Run the repair process on a set of examples.
131pub async fn run_repair(
132 examples: &mut [Example],
133 args: &RepairArgs,
134 output_path: Option<&PathBuf>,
135) -> Result<()> {
136 let client = if args.no_batch {
137 AnthropicClient::plain()?
138 } else {
139 AnthropicClient::batch(&LLM_CACHE_DB)?
140 };
141
142 eprintln!(
143 "Using model: {}, batching: {}, confidence_threshold: {}",
144 MODEL, !args.no_batch, args.confidence_threshold
145 );
146
147 // First pass: identify examples that need repair and build prompts
148 let mut repair_items: Vec<(usize, String)> = Vec::new();
149 let mut skipped_missing_data = 0;
150 let mut skipped_no_repair_needed = 0;
151
152 for (idx, example) in examples.iter().enumerate() {
153 // Skip if missing predictions or qa
154 if example.predictions.is_empty() || example.qa.is_empty() {
155 skipped_missing_data += 1;
156 continue;
157 }
158
159 // Skip if doesn't need repair
160 if !needs_repair(example, args.confidence_threshold) {
161 skipped_no_repair_needed += 1;
162 continue;
163 }
164
165 // Build repair prompt
166 let Some(prompt) = build_repair_prompt(example) else {
167 skipped_missing_data += 1;
168 continue;
169 };
170
171 repair_items.push((idx, prompt));
172 }
173
174 eprintln!(
175 "Skipping {} items with missing data, {} items that don't need repair",
176 skipped_missing_data, skipped_no_repair_needed
177 );
178 eprintln!("{} items to repair", repair_items.len());
179
180 // Process all items
181 let mut results: Vec<(usize, Option<String>)> = Vec::new();
182
183 if args.no_batch {
184 // Synchronous processing
185 for (i, (idx, prompt)) in repair_items.iter().enumerate() {
186 eprint!("\rProcessing {}/{}", i + 1, repair_items.len());
187
188 let messages = vec![Message {
189 role: Role::User,
190 content: vec![RequestContent::Text {
191 text: prompt.clone(),
192 cache_control: None,
193 }],
194 }];
195
196 let response = client.generate(MODEL, 16384, messages).await?;
197 let result = response.map(|r| {
198 r.content
199 .iter()
200 .filter_map(|c| match c {
201 anthropic::ResponseContent::Text { text } => Some(text.as_str()),
202 _ => None,
203 })
204 .collect::<Vec<_>>()
205 .join("")
206 });
207 results.push((*idx, result));
208 }
209 eprintln!();
210 } else {
211 // Queue all for batching
212 for (idx, prompt) in &repair_items {
213 let messages = vec![Message {
214 role: Role::User,
215 content: vec![RequestContent::Text {
216 text: prompt.clone(),
217 cache_control: None,
218 }],
219 }];
220
221 let response = client.generate(MODEL, 16384, messages).await?;
222 let result = response.map(|r| {
223 r.content
224 .iter()
225 .filter_map(|c| match c {
226 anthropic::ResponseContent::Text { text } => Some(text.as_str()),
227 _ => None,
228 })
229 .collect::<Vec<_>>()
230 .join("")
231 });
232 results.push((*idx, result));
233 }
234
235 // Sync batches (upload pending, download finished)
236 client.sync_batches().await?;
237
238 if args.wait {
239 eprintln!("Waiting for batch to complete...");
240 loop {
241 std::thread::sleep(std::time::Duration::from_secs(30));
242 client.sync_batches().await?;
243
244 // Re-check all items that didn't have results
245 let mut all_done = true;
246 for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
247 if results[result_idx].1.is_none() {
248 let messages = vec![Message {
249 role: Role::User,
250 content: vec![RequestContent::Text {
251 text: prompt.clone(),
252 cache_control: None,
253 }],
254 }];
255
256 let response = client.generate(MODEL, 16384, messages).await?;
257 if let Some(r) = response {
258 let text = r
259 .content
260 .iter()
261 .filter_map(|c| match c {
262 anthropic::ResponseContent::Text { text } => {
263 Some(text.as_str())
264 }
265 _ => None,
266 })
267 .collect::<Vec<_>>()
268 .join("");
269 results[result_idx] = (*idx, Some(text));
270 } else {
271 all_done = false;
272 }
273 }
274 }
275
276 let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
277 if all_done {
278 break;
279 }
280 eprintln!(
281 "Still waiting... {}/{} results",
282 done_count,
283 repair_items.len()
284 );
285 }
286 } else {
287 let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
288 if pending_count > 0 {
289 eprintln!(
290 "Batch submitted. {} pending. Run again later to retrieve results.",
291 pending_count
292 );
293 }
294 }
295 }
296
297 // Build results map by index
298 let mut results_by_idx: std::collections::HashMap<usize, String> =
299 std::collections::HashMap::new();
300 for (idx, result) in results {
301 if let Some(r) = result {
302 results_by_idx.insert(idx, r);
303 }
304 }
305
306 // Output results
307 let mut writer: Box<dyn Write> = if let Some(path) = output_path {
308 Box::new(BufWriter::new(std::fs::File::create(path)?))
309 } else {
310 Box::new(std::io::stdout())
311 };
312
313 let mut num_repaired = 0;
314 let mut num_repair_errors = 0;
315
316 for (idx, example) in examples.iter_mut().enumerate() {
317 // Add repair prediction if we have a result
318 if let Some(response_text) = results_by_idx.get(&idx) {
319 match parse_repair_response(example, response_text) {
320 Ok(prediction) => {
321 example.predictions.push(prediction);
322 num_repaired += 1;
323 }
324 Err(e) => {
325 // Add error prediction
326 example.predictions.push(ExamplePrediction {
327 actual_patch: None,
328 actual_output: response_text.clone(),
329 error: Some(format!("Failed to parse repair response: {}", e)),
330 provider: PredictionProvider::Repair,
331 });
332 num_repair_errors += 1;
333 }
334 }
335 }
336
337 writeln!(writer, "{}", serde_json::to_string(&example)?)?;
338 }
339
340 if let Some(path) = output_path {
341 eprintln!("Results written to {}", path.display());
342 }
343
344 eprintln!("Repaired: {} items", num_repaired);
345 eprintln!("Repair errors: {} items", num_repair_errors);
346
347 Ok(())
348}