repair.rs

//! Repair predictions that received poor QA scores.
//!
//! This module takes examples with predictions and QA feedback, identifies
//! predictions that need improvement (based on reverts_edits or low confidence),
//! and uses an LLM to generate improved predictions.
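//!
//! A minimal usage sketch (illustrative only, not a doctest; `load_examples`
//! is a hypothetical helper for reading a JSONL file of `Example`s):
//!
//! ```ignore
//! let mut examples: Vec<Example> = load_examples("qa_scored.jsonl")?;
//! let args = RepairArgs {
//!     no_batch: true,
//!     wait: false,
//!     confidence_threshold: 2,
//!     backend: BatchProvider::Anthropic,
//! };
//! run_repair(&mut examples, &args, Some(&PathBuf::from("repaired.jsonl"))).await?;
//! ```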

use crate::BatchProvider;
use crate::PredictionProvider;
use crate::anthropic_client::AnthropicClient;
use crate::example::{Example, ExamplePrediction};
use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
use crate::openai_client::OpenAiClient;
use crate::paths::LLM_CACHE_DB;
use crate::word_diff::unified_to_word_diff;
use anyhow::Result;
use std::io::{BufWriter, Write};
use std::path::PathBuf;

const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");

/// Arguments for the repair command.
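///
/// Typical invocation (the binary name is project-specific and shown here as
/// a placeholder):
/// `<bin> repair --backend openai --confidence-threshold 3 --wait`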
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Wait for batch to complete (polls every 30s)
    #[clap(long)]
    pub wait: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,
}

fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-5",
        BatchProvider::Openai => "gpt-5.2",
    }
}

/// Build the repair prompt for an example that needs improvement.
///
/// Returns `None` if the example lacks the required data: a prediction with
/// an actual patch, QA feedback, prompt inputs, or a cursor excerpt.
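///
/// The result is `prompts/repair.md` with its `{edit_history}`, `{context}`,
/// `{cursor_excerpt}`, `{actual_patch_word_diff}`, `{reverts_edits}`,
/// `{confidence}`, and `{qa_reasoning}` placeholders filled in.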
pub fn build_repair_prompt(example: &Example) -> Option<String> {
    let prediction = example.predictions.first()?;
    let qa = example.qa.first()?.as_ref()?;
    let prompt_inputs = example.prompt_inputs.as_ref()?;
    let actual_patch = prediction.actual_patch.as_ref()?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    // Format edit history similar to qa.rs
    let mut edit_history = String::new();
    for event in &prompt_inputs.edit_history {
        match event.as_ref() {
            zeta_prompt::Event::BufferChange {
                path,
                old_path,
                diff,
                predicted: _,
                in_open_source_repo: _,
            } => {
                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
                edit_history.push_str(&format!("+++ b{}\n", path.display()));
                let diff_word_diff = unified_to_word_diff(diff);
                edit_history.push_str(&diff_word_diff);
                edit_history.push_str("\n\n");
            }
        }
    }

    // Format related files context (reuse from TeacherPrompt)
    let context = TeacherPrompt::format_context(example);

    // Format cursor excerpt with editable region markers (reuse from format_prompt)
    let cursor_excerpt = extract_cursor_excerpt_from_example(example)?;

    // Get QA feedback
    let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
    let reverts_edits = qa
        .reverts_edits
        .map_or("unknown", |v| if v { "yes" } else { "no" });
    let confidence = qa
        .confidence
        .map_or("unknown".to_string(), |v| v.to_string());

    Some(
        PROMPT_TEMPLATE
            .replace("{edit_history}", &edit_history)
            .replace("{context}", &context)
            .replace("{cursor_excerpt}", &cursor_excerpt)
            .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
            .replace("{reverts_edits}", reverts_edits)
            .replace("{confidence}", &confidence)
            .replace("{qa_reasoning}", qa_reasoning),
    )
}

/// Check if an example needs repair based on QA feedback.
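///
/// Repair is needed when QA marked `reverts_edits` as true, or when the QA
/// `confidence` score is at or below `confidence_threshold`. Examples without
/// QA feedback never need repair.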
pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
    let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) else {
        return false;
    };

    // Repair if reverts_edits is true
    if qa.reverts_edits == Some(true) {
        return true;
    }

    // Repair if confidence is at or below threshold
    if let Some(confidence) = qa.confidence {
        if confidence <= confidence_threshold {
            return true;
        }
    }

    false
}

/// Parse the repair response into a prediction.
fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
    let actual_patch = TeacherPrompt::parse(example, response_text)?;

    Ok(ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output: response_text.to_string(),
        error: None,
        provider: PredictionProvider::Repair,
    })
}

enum RepairClient {
    Anthropic(AnthropicClient),
    OpenAi(OpenAiClient),
}

impl RepairClient {
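    /// Send `prompt` to the selected provider and return the response text.
    ///
    /// Returns `Ok(None)` when the underlying client is in batch mode and the
    /// request has only been queued (or is still pending); the result becomes
    /// available after a later `sync_batches` call.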
    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
        match self {
            RepairClient::Anthropic(client) => {
                let messages = vec![anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: prompt.to_string(),
                        cache_control: None,
                    }],
                }];
                let response = client.generate(model, max_tokens, messages, None).await?;
                Ok(response.map(|r| {
                    r.content
                        .iter()
                        .filter_map(|c| match c {
                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
            RepairClient::OpenAi(client) => {
                let messages = vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt.to_string()),
                }];
                let response = client.generate(model, max_tokens, messages, None).await?;
                Ok(response.map(|r| {
                    r.choices
                        .into_iter()
                        .filter_map(|choice| match choice.message {
                            open_ai::RequestMessage::Assistant { content, .. } => {
                                content.map(|c| match c {
                                    open_ai::MessageContent::Plain(text) => text,
                                    open_ai::MessageContent::Multipart(parts) => parts
                                        .into_iter()
                                        .filter_map(|p| match p {
                                            open_ai::MessagePart::Text { text } => Some(text),
                                            _ => None,
                                        })
                                        .collect::<Vec<_>>()
                                        .join(""),
                                })
                            }
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("")
                }))
            }
        }
    }

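    /// Upload pending batch requests and download any finished results.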
    async fn sync_batches(&self) -> Result<()> {
        match self {
            RepairClient::Anthropic(client) => client.sync_batches().await,
            RepairClient::OpenAi(client) => client.sync_batches().await,
        }
    }
}

/// Run the repair process on a set of examples.
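///
/// Each repaired example gets a new repair prediction appended (or an error
/// prediction when the response can't be parsed). Every example is then
/// written as a JSON line to `output_path`, or to stdout when it is `None`.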
pub async fn run_repair(
    examples: &mut [Example],
    args: &RepairArgs,
    output_path: Option<&PathBuf>,
) -> Result<()> {
    let model = model_for_backend(args.backend);
    let client = match args.backend {
        BatchProvider::Anthropic => {
            if args.no_batch {
                RepairClient::Anthropic(AnthropicClient::plain()?)
            } else {
                RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
            }
        }
        BatchProvider::Openai => {
            if args.no_batch {
                RepairClient::OpenAi(OpenAiClient::plain()?)
            } else {
                RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
            }
        }
    };

    eprintln!(
        "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}",
        model, args.backend, !args.no_batch, args.confidence_threshold
    );

    // First pass: identify examples that need repair and build prompts
    let mut repair_items: Vec<(usize, String)> = Vec::new();
    let mut skipped_missing_data = 0;
    let mut skipped_no_repair_needed = 0;

    for (idx, example) in examples.iter().enumerate() {
        // Skip if missing predictions or qa
        if example.predictions.is_empty() || example.qa.is_empty() {
            skipped_missing_data += 1;
            continue;
        }

        // Skip if it doesn't need repair
        if !needs_repair(example, args.confidence_threshold) {
            skipped_no_repair_needed += 1;
            continue;
        }

        // Build repair prompt
        let Some(prompt) = build_repair_prompt(example) else {
            skipped_missing_data += 1;
            continue;
        };

        repair_items.push((idx, prompt));
    }

    eprintln!(
        "Skipping {} items with missing data, {} items that don't need repair",
        skipped_missing_data, skipped_no_repair_needed
    );
    eprintln!("{} items to repair", repair_items.len());

    // Process all items
    let mut results: Vec<(usize, Option<String>)> = Vec::new();

    if args.no_batch {
        // Synchronous processing
        for (i, (idx, prompt)) in repair_items.iter().enumerate() {
            eprint!("\rProcessing {}/{}", i + 1, repair_items.len());

            let response = client.generate(model, 16384, prompt).await?;
            results.push((*idx, response));
        }
        eprintln!();
    } else {
        // Queue all for batching
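        // In batch mode, generate() enqueues the request and returns Ok(None)
        // until the batch has been uploaded and its results downloaded.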
        for (idx, prompt) in &repair_items {
            let response = client.generate(model, 16384, prompt).await?;
            results.push((*idx, response));
        }

        // Sync batches (upload pending, download finished)
        client.sync_batches().await?;

        if args.wait {
            eprintln!("Waiting for batch to complete...");
            loop {
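                // NOTE: this is a blocking sleep inside an async fn; it stalls
                // the executor thread, which this sequential CLI loop tolerates
                // since nothing else runs concurrently while polling.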
                std::thread::sleep(std::time::Duration::from_secs(30));
                client.sync_batches().await?;

                // Re-check all items that didn't have results
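                // At this point generate() acts as a cache lookup: completed
                // batch entries yield Some(text), still-pending ones None.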
                let mut all_done = true;
                for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
                    if results[result_idx].1.is_none() {
                        let response = client.generate(model, 16384, prompt).await?;
                        if let Some(text) = response {
                            results[result_idx] = (*idx, Some(text));
                        } else {
                            all_done = false;
                        }
                    }
                }

                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
                if all_done {
                    break;
                }
                eprintln!(
                    "Still waiting... {}/{} results",
                    done_count,
                    repair_items.len()
                );
            }
        } else {
            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
            if pending_count > 0 {
                eprintln!(
                    "Batch submitted. {} pending. Run again later to retrieve results.",
                    pending_count
                );
            }
        }
    }

    // Build results map by index
    let mut results_by_idx: std::collections::HashMap<usize, String> =
        std::collections::HashMap::new();
    for (idx, result) in results {
        if let Some(r) = result {
            results_by_idx.insert(idx, r);
        }
    }

    // Output results
    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
        Box::new(BufWriter::new(std::fs::File::create(path)?))
    } else {
        Box::new(std::io::stdout())
    };

    let mut num_repaired = 0;
    let mut num_repair_errors = 0;

    for (idx, example) in examples.iter_mut().enumerate() {
        // Add repair prediction if we have a result
        if let Some(response_text) = results_by_idx.get(&idx) {
            match parse_repair_response(example, response_text) {
                Ok(prediction) => {
                    example.predictions.push(prediction);
                    num_repaired += 1;
                }
                Err(e) => {
                    // Add error prediction
                    example.predictions.push(ExamplePrediction {
                        actual_patch: None,
                        actual_output: response_text.clone(),
                        error: Some(format!("Failed to parse repair response: {}", e)),
                        provider: PredictionProvider::Repair,
                    });
                    num_repair_errors += 1;
                }
            }
        }

        writeln!(writer, "{}", serde_json::to_string(&example)?)?;
    }

    if let Some(path) = output_path {
        eprintln!("Results written to {}", path.display());
    }

    eprintln!("Repaired:      {} items", num_repaired);
    eprintln!("Repair errors: {} items", num_repair_errors);

    Ok(())
}