// repair.rs

//! Repair predictions that received poor quality signals.
//!
//! This module takes examples with predictions, identifies predictions that need
//! improvement, and uses an LLM to generate improved predictions. It supports
//! two sources of quality signals:
//! - QA feedback (reverts_edits or low confidence)
//! - Computed scores when QA is unavailable (high reversal_ratio or wrong_editable_region)
  9use crate::{
 10    BatchProvider, PredictionProvider,
 11    anthropic_client::AnthropicClient,
 12    example::{ActualCursor, Example, ExamplePrediction},
 13    format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example, extract_last_codeblock},
 14    openai_client::OpenAiClient,
 15    parse_output::run_parse_output,
 16    paths::LLM_CACHE_DB,
 17    progress::{ExampleProgress, Step},
 18    word_diff::unified_to_word_diff,
 19};
 20use anyhow::{Context as _, Result};
 21use std::sync::OnceLock;
 22
/// Sentinel the repair model emits when the previous prediction should be kept
/// as-is; `parse` recognizes it and copies the original prediction's patch/cursor.
const KEEP_PREVIOUS: &str = "KEEP_PREVIOUS";
 24
 25/// Print a summary report of repair results across all examples.
 26pub fn print_report(examples: &[Example], confidence_threshold: u8) {
 27    let total = examples.len();
 28    let mut no_repair_needed = 0;
 29    let mut repaired = 0;
 30    let mut repair_failed = 0;
 31
 32    for example in examples {
 33        if !needs_repair(example, confidence_threshold) {
 34            no_repair_needed += 1;
 35            continue;
 36        }
 37
 38        if has_successful_repair(example) {
 39            repaired += 1;
 40        } else {
 41            repair_failed += 1;
 42        }
 43    }
 44
 45    let needed_repair = total - no_repair_needed;
 46
 47    eprintln!();
 48    eprintln!("Repair summary ({total} examples):");
 49    eprintln!(
 50        "  {no_repair_needed}/{total} didn't need repair (confidence > {confidence_threshold})"
 51    );
 52    if needed_repair > 0 {
 53        eprintln!("  {needed_repair}/{total} needed repair:");
 54        if repaired > 0 {
 55            eprintln!("    {repaired} repaired successfully");
 56        }
 57        if repair_failed > 0 {
 58            eprintln!("    {repair_failed} failed to repair");
 59        }
 60    }
 61}
 62
/// Arguments for the repair command.
// NOTE: the `///` docs on this struct's fields double as clap-generated CLI
// help text, so they are user-visible; edit them with care.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,
}
 78
 79fn model_for_backend(backend: BatchProvider) -> &'static str {
 80    match backend {
 81        BatchProvider::Anthropic => "claude-sonnet-4-5",
 82        BatchProvider::Openai => "gpt-5.2",
 83    }
 84}
 85
 86/// Build the quality feedback string from QA results.
 87fn build_qa_feedback(example: &Example) -> Option<String> {
 88    let qa = example.qa.first()?.as_ref()?;
 89
 90    let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
 91    let reverts_edits = qa
 92        .reverts_edits
 93        .map_or("unknown", |v| if v { "yes" } else { "no" });
 94    let confidence = qa
 95        .confidence
 96        .map_or("unknown".to_string(), |v| v.to_string());
 97
 98    Some(format!(
 99        "- **Reverts user edits**: {reverts_edits}\n\
100         - **Confidence score**: {confidence}/5\n\
101         - **Reasoning**: {qa_reasoning}"
102    ))
103}
104
105/// Build the quality feedback string from computed scores when QA is unavailable.
106fn build_score_feedback(example: &Example) -> Option<String> {
107    let score = example.score.first()?;
108
109    let mut issues = Vec::new();
110
111    if score.reversal_ratio > 0.9 {
112        issues.push(format!(
113            "Automated analysis detected a high reversal ratio ({:.2}), which suggests this \
114             prediction may be reverting changes the user intentionally made. Double-check that \
115             the prediction doesn't undo the user's recent edits. If the prediction is actually \
116             fine and the edits are intentional completions rather than reversals, keep it as-is. \
117             If it truly reverts the user's changes, generate an improved prediction that \
118             continues the user's intent instead.",
119            score.reversal_ratio
120        ));
121    }
122
123    if score.wrong_editable_region == Some(true) {
124        issues.push(
125            "Automated analysis detected that the prediction may be modifying code outside \
126             the expected editable region, or producing changes misaligned with the editable \
127             region boundaries. Make sure the prediction only modifies code within the editable \
128             region and is properly aligned."
129                .to_string(),
130        );
131    }
132
133    if issues.is_empty() {
134        return None;
135    }
136
137    let mut feedback = String::from(
138        "No human quality assessment is available, but automated scoring flagged potential issues:\n\n",
139    );
140    for issue in &issues {
141        feedback.push_str(&format!("- {issue}\n"));
142    }
143    feedback.push_str(
144        "\nRemember: if the previous prediction was actually correct, output `KEEP_PREVIOUS`. \
145         If no edits should be made at all and you are unsure how to improve it, output `NO_EDITS`.",
146    );
147
148    Some(feedback)
149}
150
151/// Build the repair prompt for an example that needs improvement.
152pub fn build_repair_prompt(example: &Example) -> Result<String> {
153    let prediction = example
154        .predictions
155        .first()
156        .context("no predictions available")?;
157    let prompt_inputs = example
158        .prompt_inputs
159        .as_ref()
160        .context("prompt_inputs missing (run context retrieval first)")?;
161    let actual_patch = prediction
162        .actual_patch
163        .as_ref()
164        .context("no actual_patch available (run predict first)")?;
165
166    let quality_feedback = build_qa_feedback(example)
167        .or_else(|| build_score_feedback(example))
168        .context("no quality feedback available (need either QA results or computed scores)")?;
169
170    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
171
172    let mut edit_history = String::new();
173    for event in &prompt_inputs.edit_history {
174        match event.as_ref() {
175            zeta_prompt::Event::BufferChange {
176                path,
177                old_path,
178                diff,
179                predicted: _,
180                in_open_source_repo: _,
181            } => {
182                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
183                edit_history.push_str(&format!("+++ b{}\n", path.display()));
184                let diff_word_diff = unified_to_word_diff(diff);
185                edit_history.push_str(&diff_word_diff);
186                edit_history.push_str("\n\n");
187            }
188        }
189    }
190
191    let context = TeacherPrompt::format_context(example);
192
193    let cursor_excerpt =
194        extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?;
195
196    let prompt_template = crate::prompt_assets::get_prompt("repair.md");
197    Ok(prompt_template
198        .replace("{edit_history}", &edit_history)
199        .replace("{context}", &context)
200        .replace("{cursor_excerpt}", &cursor_excerpt)
201        .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
202        .replace("{quality_feedback}", &quality_feedback))
203}
204
205/// Check if an example needs repair based on QA feedback or computed scores.
206pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
207    // Check QA-based signals first.
208    if let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) {
209        if qa.reverts_edits == Some(true) {
210            return true;
211        }
212
213        if let Some(confidence) = qa.confidence {
214            if confidence <= confidence_threshold {
215                return true;
216            }
217        }
218
219        return false;
220    }
221
222    // When QA is unavailable, fall back to computed score signals.
223    if let Some(score) = example.score.first() {
224        if score.reversal_ratio > 0.9 {
225            return true;
226        }
227
228        if score.wrong_editable_region == Some(true) {
229            return true;
230        }
231    }
232
233    false
234}
235
236/// Parse repair model output into a patch and optional cursor.
237///
238/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
239/// and delegates normal output to `TeacherPrompt::parse`.
240pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
241    let last_codeblock = extract_last_codeblock(actual_output);
242    if last_codeblock.trim() == KEEP_PREVIOUS {
243        let original = example
244            .predictions
245            .first()
246            .context("no original prediction to keep")?;
247        let patch = original.actual_patch.clone().unwrap_or_default();
248        let cursor = original.actual_cursor.clone();
249        return Ok((patch, cursor));
250    }
251
252    TeacherPrompt::parse(example, actual_output)
253}
254
255/// Check if an example already has a successful repair prediction.
256fn has_successful_repair(example: &Example) -> bool {
257    example
258        .predictions
259        .iter()
260        .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
261}
262
// Process-wide, lazily initialized LLM clients — one per (provider, batch-mode)
// combination — so repeated repair calls reuse a single client instance.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
267
/// Run repair for a single example.
///
/// No-op when the example already has a successful repair prediction or when
/// `needs_repair` says it doesn't need one. Otherwise builds a repair prompt,
/// queries the configured backend, parses the response, and appends a
/// `PredictionProvider::Repair` prediction to `example.predictions`. A parse
/// failure is recorded on the new prediction's `error` field rather than
/// returned as `Err`.
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Idempotence: don't redo work for examples already repaired.
    if has_successful_repair(example) {
        return Ok(());
    }

    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // Fail fast before any LLM work if earlier pipeline steps haven't run.
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let prompt = build_repair_prompt(example).context("Failed to build repair prompt")?;

    step_progress.set_substatus("generating");

    let response = match args.backend {
        BatchProvider::Anthropic => {
            // Reuse a cached client; which cache depends on batch vs. plain mode.
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![anthropic::Message {
                role: anthropic::Role::User,
                content: vec![anthropic::RequestContent::Text {
                    text: prompt,
                    cache_control: None,
                }],
            }];

            // `None` means no result is available yet (presumably a pending
            // batch request — TODO confirm) — return and retry on a later pass.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text content blocks; ignore non-text content.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }];

            // Same early return as the Anthropic arm when no result is ready.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Flatten assistant message content (plain or multipart) into a
            // single text string, joining across choices.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    // Capture a parse failure as a string so it can be stored on the
    // prediction instead of aborting the repair step.
    let parse_result = parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {}", e));

    let (actual_patch, actual_cursor) = parse_result.ok().unzip();
    let actual_cursor = actual_cursor.flatten();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor,
        error: err,
        provider: PredictionProvider::Repair,
    });

    Ok(())
}
395
396/// Sync batches for repair (upload pending requests, download finished results).
397pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
398    if args.no_batch {
399        return Ok(());
400    }
401
402    match args.backend {
403        BatchProvider::Anthropic => {
404            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
405                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
406            });
407            client.sync_batches().await?;
408        }
409        BatchProvider::Openai => {
410            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
411                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
412            });
413            client.sync_batches().await?;
414        }
415    }
416
417    Ok(())
418}