repair.rs

//! Repair predictions that received poor quality signals.
//!
//! This module takes examples with predictions, identifies the ones that need
//! improvement, and uses an LLM to generate improved predictions. It supports
//! two sources of quality signals:
//! - QA feedback (`reverts_edits` or low confidence)
//! - Computed scores when QA is unavailable (high `reversal_ratio` or `wrong_editable_region`)

use crate::{
    BatchProvider, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{ActualCursor, Example, ExamplePrediction},
    format_prompt::TeacherPrompt,
    metrics::count_patch_token_changes,
    openai_client::OpenAiClient,
    parse_output::run_parse_output,
    paths::LLM_CACHE_DB,
    progress::{ExampleProgress, Progress, Step},
    word_diff::unified_to_word_diff,
};
use anyhow::{Context as _, Result};
use std::sync::OnceLock;

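/// Sentinel the repair prompt instructs the model to output when the previous
/// prediction should be kept unchanged.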
const KEEP_PREVIOUS: &str = "KEEP_PREVIOUS";

/// Print a summary report of repair results across all examples.
pub fn print_report(examples: &[Example], confidence_threshold: u8) {
    let total = examples.len();
    let mut no_repair_needed = 0;
    let mut repaired = 0;
    let mut repair_failed = 0;

    for example in examples {
        if !needs_repair(example, confidence_threshold) {
            no_repair_needed += 1;
            continue;
        }

        if has_successful_repair(example) {
            repaired += 1;
        } else {
            repair_failed += 1;
        }
    }

    let needed_repair = total - no_repair_needed;

    eprintln!();
    eprintln!("Repair summary ({total} examples):");
    eprintln!(
        "  {no_repair_needed}/{total} didn't need repair (confidence > {confidence_threshold})"
    );
    if needed_repair > 0 {
        eprintln!("  {needed_repair}/{total} needed repair:");
        if repaired > 0 {
            eprintln!("    {repaired} repaired successfully");
        }
        if repair_failed > 0 {
            eprintln!("    {repair_failed} failed to repair");
        }
    }
}

/// Arguments for the repair command.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,

    /// Wait for all batches to complete before exiting
    #[clap(long)]
    pub wait: bool,
}

fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-6",
        BatchProvider::Openai => "gpt-5.2",
    }
}

/// Build the quality feedback string from QA results.
fn build_qa_feedback(example: &Example) -> Option<String> {
    let qa = example.qa.first()?.as_ref()?;

    let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
    let reverts_edits = qa
        .reverts_edits
        .map_or("unknown", |v| if v { "yes" } else { "no" });
    let confidence = qa
        .confidence
        .map_or("unknown".to_string(), |v| v.to_string());

    Some(format!(
        "- **Reverts user edits**: {reverts_edits}\n\
         - **Confidence score**: {confidence}/5\n\
         - **Reasoning**: {qa_reasoning}"
    ))
}

/// Build the quality feedback string from computed scores when QA is unavailable.
fn build_score_feedback(example: &Example) -> Option<String> {
    let score = example.score.first()?;

    let mut issues = Vec::new();

    if score.reversal_ratio > 0.9 {
        issues.push(format!(
            "Automated analysis detected a high reversal ratio ({:.2}), which suggests this \
             prediction may be reverting changes the user intentionally made. Double-check that \
             the prediction doesn't undo the user's recent edits. If the prediction is actually \
             fine and the edits are intentional completions rather than reversals, keep it as-is. \
             If it truly reverts the user's changes, generate an improved prediction that \
             continues the user's intent instead.",
            score.reversal_ratio
        ));
    }

    if score.wrong_editable_region == Some(true) {
        issues.push(
            "Automated analysis detected that the prediction may be modifying code outside \
             the expected editable region, or producing changes misaligned with the editable \
             region boundaries. Make sure the prediction only modifies code within the editable \
             region and is properly aligned."
                .to_string(),
        );
    }

    if score.discarded_chars.unwrap_or(0) > 80 && score.exact_lines_fp > 5 {
        issues.push(
            "Automated analysis detected that this prediction might be too large or speculative. \
             Review it and decide whether to keep it or to generate a more focused prediction. \
             Examples of more focused predictions:\n\
             - Predicting a function outline but not its body.\n\
             - Predicting only the first logical step and not speculating about further steps.\n\
             In general, the smaller the prediction you make, the higher the chance it will \
             be correct."
                .to_string(),
        );
    }

    if issues.is_empty() {
        return None;
    }

    let mut feedback = String::from(
        "No human quality assessment is available, but automated scoring flagged potential issues:\n\n",
    );
    for issue in &issues {
        feedback.push_str(&format!("- {issue}\n"));
    }
    feedback.push_str(
        "\nRemember: if the previous prediction was actually correct, output `KEEP_PREVIOUS`. \
         If no edits should be made at all and you are unsure how to improve it, output `NO_EDITS`.",
    );

    Some(feedback)
}

/// Build the repair message (Turn 3) for a multi-turn conversation.
///
/// This message is sent after the original teacher prompt (Turn 1) and
/// teacher response (Turn 2) to request an improved prediction.
pub fn build_repair_message(example: &Example) -> Result<String> {
    let prediction = example
        .predictions
        .first()
        .context("no predictions available")?;
    let actual_patch = prediction
        .actual_patch
        .as_ref()
        .context("no actual_patch available (run predict first)")?;

    let quality_feedback = build_qa_feedback(example)
        .or_else(|| build_score_feedback(example))
        .context("no quality feedback available (need either QA results or computed scores)")?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    let token_counts = count_patch_token_changes(actual_patch);
    let mut token_change_info = format!(
        "\n## Token Change Statistics\n\n\
         - **Deleted tokens**: {}\n\
         - **Inserted tokens**: {}",
        token_counts.deleted_tokens, token_counts.inserted_tokens,
    );
    if token_counts.deleted_tokens > 100 || token_counts.inserted_tokens > 100 {
        token_change_info.push_str(
            "\n\n> **Note:** The token change count is high. \
             Consider producing a more scoped edit that targets only the lines \
             that truly need to change, rather than rewriting large sections.",
        );
    }

    let prompt_template = crate::prompt_assets::get_prompt("repair.md");
    Ok(prompt_template
        .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
        .replace("{quality_feedback}", &quality_feedback)
        .replace("{token_change_info}", &token_change_info))
}

/// Check if an example needs repair based on QA feedback or computed scores.
pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
    // Check QA-based signals first.
    if let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) {
        if qa.reverts_edits == Some(true) {
            return true;
        }

        if let Some(confidence) = qa.confidence {
            if confidence <= confidence_threshold {
                return true;
            }
        }

        return false;
    }

    // When QA is unavailable, fall back to computed score signals.
    if let Some(score) = example.score.first() {
        if score.reversal_ratio > 0.9 {
            return true;
        }

        if score.wrong_editable_region == Some(true) {
            return true;
        }
    }

    false
}

/// Parse repair model output into a patch and optional cursor.
///
/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
/// and delegates normal output to `TeacherPrompt::parse`.
pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
    if actual_output.contains(KEEP_PREVIOUS) {
        let original = example
            .predictions
            .first()
            .context("no original prediction to keep")?;
        let patch = original.actual_patch.clone().unwrap_or_default();
        let cursor = original.actual_cursor.clone();
        return Ok((patch, cursor));
    }

    TeacherPrompt::parse(example, actual_output)
}

/// Check if an example already has a successful repair prediction.
fn has_successful_repair(example: &Example) -> bool {
    example
        .predictions
        .iter()
        .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
}

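// LLM clients are created lazily on first use and shared across all examples for
// the duration of the run.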
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();

/// Run repair for a single example.
///
/// This sends a multi-turn conversation to the LLM:
/// - Turn 1 (User): Original teacher prompt
/// - Turn 2 (Assistant): Original teacher response
/// - Turn 3 (User): Repair critique and instructions
/// - Turn 4 (Assistant): Improved prediction (the response we parse)
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    if has_successful_repair(example) {
        return Ok(());
    }

    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    let teacher_prompt = example
        .prompt
        .as_ref()
        .context("prompt missing (run format_prompt first)")?;

    let teacher_response = &example.predictions[0].actual_output;
    if teacher_response.is_empty() {
        anyhow::bail!("teacher response is empty (run predict first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let repair_message = build_repair_message(example).context("Failed to build repair message")?;

    step_progress.set_substatus("generating");

    let response = match args.backend {
        BatchProvider::Anthropic => {
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_prompt.input.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 2: Original teacher response
                anthropic::Message {
                    role: anthropic::Role::Assistant,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_response.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 3: Repair critique and instructions
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: repair_message,
                        cache_control: None,
                    }],
                },
            ];

            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(teacher_prompt.input.clone()),
                },
                // Turn 2: Original teacher response
                open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(teacher_response.clone())),
                    tool_calls: vec![],
                    reasoning_content: None,
                },
                // Turn 3: Repair critique and instructions
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(repair_message),
                },
            ];

            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

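    // A parse failure is recorded on the prediction rather than returned as an error,
    // so the raw model output is still saved for later inspection.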
    let parse_result = parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {}", e));

    let (actual_patch, actual_cursor) = parse_result.ok().unzip();
    let actual_cursor = actual_cursor.flatten();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor,
        error: err,
        provider: PredictionProvider::Repair,
        cumulative_logprob: None,
        avg_logprob: None,
    });

    Ok(())
}

/// Sync batches for repair (upload pending requests, download finished results).
pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
    if args.no_batch {
        return Ok(());
    }

    match args.backend {
        BatchProvider::Anthropic => {
            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
            });
            client.sync_batches().await?;
        }
        BatchProvider::Openai => {
            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
            });
            client.sync_batches().await?;
        }
    }

    Ok(())
}

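/// Re-run repair for examples that still need it, picking up batch results that
/// arrived while waiting.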
pub async fn reprocess_after_batch_wait(
    examples: &mut [Example],
    args: &RepairArgs,
) -> Result<()> {
    let mut reprocessed = 0;
    for example in examples.iter_mut() {
        if has_successful_repair(example) || !needs_repair(example, args.confidence_threshold) {
            continue;
        }

        let example_progress = Progress::global().start_group(&example.spec.name);
        run_repair(example, args, &example_progress).await?;
        reprocessed += 1;
    }

    if reprocessed > 0 {
        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
    }

    Ok(())
}

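/// Poll until all pending repair batch requests have completed, syncing results
/// between polls.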
pub async fn wait_for_batches(args: &RepairArgs) -> Result<()> {
    if args.no_batch {
        return Ok(());
    }

    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(args)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending repair batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        std::thread::sleep(poll_interval);

        sync_batches(args).await?;
    }

    Ok(())
}

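/// Number of repair batch requests still awaiting results from the provider.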
fn pending_batch_count(args: &RepairArgs) -> Result<usize> {
    match args.backend {
        BatchProvider::Anthropic => {
            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
            });
            client.pending_batch_count()
        }
        BatchProvider::Openai => {
            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
            });
            client.pending_batch_count()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PredictionProvider, TeacherBackend};
    use edit_prediction::example_spec::ExampleSpec;
    use std::{path::Path, sync::Arc};
    use zeta_prompt::ZetaFormat;

    fn example_with_previous_prediction() -> Example {
        Example {
            spec: ExampleSpec {
                name: "example".to_string(),
                repository_url: "https://github.com/zed-industries/zed.git".to_string(),
                revision: "HEAD".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: String::new(),
                cursor_path: Arc::from(Path::new("src/main.rs")),
                cursor_position: "0:0".to_string(),
                edit_history: String::new(),
                expected_patches: Vec::new(),
                rejected_patch: None,
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            },
            prompt_inputs: None,
            prompt: None,
            predictions: vec![ExamplePrediction {
                actual_patch: Some("previous patch".to_string()),
                actual_output: String::new(),
                actual_cursor: Some(ActualCursor {
                    path: "src/main.rs".to_string(),
                    row: 1,
                    column: 2,
                    offset: 3,
                    editable_region_offset: Some(4),
                }),
                error: None,
                provider: PredictionProvider::Teacher(
                    TeacherBackend::Sonnet45,
                    ZetaFormat::default(),
                ),
                cumulative_logprob: None,
                avg_logprob: None,
            }],
            score: Vec::new(),
            qa: Vec::new(),
            zed_version: None,
            state: None,
        }
    }

    #[test]
    fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
        let example = example_with_previous_prediction();
        let actual_output = indoc::indoc! {"
            After reviewing the feedback, the previous prediction is still correct.
            Use `KEEP_PREVIOUS`.

            ```
            unrelated trailing code block
            ```
        "};

        let (patch, cursor) = parse(&example, actual_output).unwrap();

        assert_eq!(patch, "previous patch");
        assert_eq!(cursor.unwrap().offset, 3);
    }
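
    // A couple of extra sanity checks for the helpers above. These only use types
    // and fields already constructed in this module, so no new fixtures are assumed.
    #[test]
    fn test_needs_repair_is_false_without_qa_or_score_signals() {
        // With neither QA feedback nor computed scores, repair should not trigger.
        let example = example_with_previous_prediction();
        assert!(!needs_repair(&example, 2));
    }

    #[test]
    fn test_has_successful_repair_requires_repair_provider_with_patch() {
        let mut example = example_with_previous_prediction();
        // The original teacher prediction alone does not count as a repair.
        assert!(!has_successful_repair(&example));

        example.predictions.push(ExamplePrediction {
            actual_patch: Some("repaired patch".to_string()),
            actual_output: String::new(),
            actual_cursor: None,
            error: None,
            provider: PredictionProvider::Repair,
            cumulative_logprob: None,
            avg_logprob: None,
        });
        assert!(has_successful_repair(&example));
    }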
}