// repair.rs

  1//! Repair predictions that received poor quality signals.
  2//!
  3//! This module takes examples with predictions, identifies predictions that need
  4//! improvement, and uses an LLM to generate improved predictions. It supports
  5//! two sources of quality signals:
  6//! - QA feedback (reverts_edits or low confidence)
  7//! - Computed scores when QA is unavailable (high reversal_ratio or wrong_editable_region)
  8
  9use crate::{
 10    BatchProvider, PredictionProvider,
 11    anthropic_client::AnthropicClient,
 12    example::{ActualCursor, Example, ExamplePrediction},
 13    format_prompt::TeacherPrompt,
 14    metrics::count_patch_token_changes,
 15    openai_client::OpenAiClient,
 16    parse_output::run_parse_output,
 17    paths::LLM_CACHE_DB,
 18    progress::{ExampleProgress, Progress, Step},
 19    word_diff::unified_to_word_diff,
 20};
 21use anyhow::{Context as _, Result};
 22use std::sync::OnceLock;
 23
/// Sentinel token the repair model emits when the previous prediction should
/// be kept unchanged; handled in [`parse`] by copying the teacher prediction.
const KEEP_PREVIOUS: &str = "KEEP_PREVIOUS";
 25
 26/// Print a summary report of repair results across all examples.
 27pub fn print_report(examples: &[Example], confidence_threshold: u8) {
 28    let total = examples.len();
 29    let mut no_repair_needed = 0;
 30    let mut repaired = 0;
 31    let mut repair_failed = 0;
 32
 33    for example in examples {
 34        if !needs_repair(example, confidence_threshold) {
 35            no_repair_needed += 1;
 36            continue;
 37        }
 38
 39        if has_successful_repair(example) {
 40            repaired += 1;
 41        } else {
 42            repair_failed += 1;
 43        }
 44    }
 45
 46    let needed_repair = total - no_repair_needed;
 47
 48    eprintln!();
 49    eprintln!("Repair summary ({total} examples):");
 50    eprintln!(
 51        "  {no_repair_needed}/{total} didn't need repair (confidence > {confidence_threshold})"
 52    );
 53    if needed_repair > 0 {
 54        eprintln!("  {needed_repair}/{total} needed repair:");
 55        if repaired > 0 {
 56            eprintln!("    {repaired} repaired successfully");
 57        }
 58        if repair_failed > 0 {
 59            eprintln!("    {repair_failed} failed to repair");
 60        }
 61    }
 62}
 63
 64/// Arguments for the repair command.
 65#[derive(Debug, Clone, clap::Args)]
 66pub struct RepairArgs {
 67    /// Use synchronous API instead of batch
 68    #[clap(long)]
 69    pub no_batch: bool,
 70
 71    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
 72    #[clap(long, default_value = "2")]
 73    pub confidence_threshold: u8,
 74
 75    /// Which LLM provider to use (anthropic or openai)
 76    #[clap(long, default_value = "anthropic")]
 77    pub backend: BatchProvider,
 78    /// Wait for all batches to complete before exiting
 79    #[clap(long)]
 80    pub wait: bool,
 81}
 82
/// Model identifier used for repair generation with each backend.
fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-6",
        BatchProvider::Openai => "gpt-5.2",
    }
}
 89
 90/// Build the quality feedback string from QA results.
 91fn build_qa_feedback(example: &Example) -> Option<String> {
 92    let qa = example.qa.first()?.as_ref()?;
 93
 94    let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
 95    let reverts_edits = qa
 96        .reverts_edits
 97        .map_or("unknown", |v| if v { "yes" } else { "no" });
 98    let confidence = qa
 99        .confidence
100        .map_or("unknown".to_string(), |v| v.to_string());
101
102    Some(format!(
103        "- **Reverts user edits**: {reverts_edits}\n\
104         - **Confidence score**: {confidence}/5\n\
105         - **Reasoning**: {qa_reasoning}"
106    ))
107}
108
109/// Build the quality feedback string from computed scores when QA is unavailable.
110fn build_score_feedback(example: &Example) -> Option<String> {
111    let score = example.score.first()?;
112
113    let mut issues = Vec::new();
114
115    if score.reversal_ratio > 0.9 {
116        issues.push(format!(
117            "Automated analysis detected a high reversal ratio ({:.2}), which suggests this \
118             prediction may be reverting changes the user intentionally made. Double-check that \
119             the prediction doesn't undo the user's recent edits. If the prediction is actually \
120             fine and the edits are intentional completions rather than reversals, keep it as-is. \
121             If it truly reverts the user's changes, generate an improved prediction that \
122             continues the user's intent instead.",
123            score.reversal_ratio
124        ));
125    }
126
127    if score.wrong_editable_region == Some(true) {
128        issues.push(
129            "Automated analysis detected that the prediction may be modifying code outside \
130             the expected editable region, or producing changes misaligned with the editable \
131             region boundaries. Make sure the prediction only modifies code within the editable \
132             region and is properly aligned."
133                .to_string(),
134        );
135    }
136
137    if issues.is_empty() {
138        return None;
139    }
140
141    let mut feedback = String::from(
142        "No human quality assessment is available, but automated scoring flagged potential issues:\n\n",
143    );
144    for issue in &issues {
145        feedback.push_str(&format!("- {issue}\n"));
146    }
147    feedback.push_str(
148        "\nRemember: if the previous prediction was actually correct, output `KEEP_PREVIOUS`. \
149         If no edits should be made at all and you are unsure how to improve it, output `NO_EDITS`.",
150    );
151
152    Some(feedback)
153}
154
155/// Build the repair message (Turn 3) for a multi-turn conversation.
156///
157/// This message is sent after the original teacher prompt (Turn 1) and
158/// teacher response (Turn 2) to request an improved prediction.
159pub fn build_repair_message(example: &Example) -> Result<String> {
160    let prediction = example
161        .predictions
162        .first()
163        .context("no predictions available")?;
164    let actual_patch = prediction
165        .actual_patch
166        .as_ref()
167        .context("no actual_patch available (run predict first)")?;
168
169    let quality_feedback = build_qa_feedback(example)
170        .or_else(|| build_score_feedback(example))
171        .context("no quality feedback available (need either QA results or computed scores)")?;
172
173    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
174
175    let token_counts = count_patch_token_changes(actual_patch);
176    let mut token_change_info = format!(
177        "\n## Token Change Statistics\n\n\
178         - **Deleted tokens**: {}\n\
179         - **Inserted tokens**: {}",
180        token_counts.deleted_tokens, token_counts.inserted_tokens,
181    );
182    if token_counts.deleted_tokens > 100 || token_counts.inserted_tokens > 100 {
183        token_change_info.push_str(
184            "\n\n> **Note:** The token change count is high. \
185             Consider producing a more scoped edit that targets only the lines \
186             that truly need to change, rather than rewriting large sections.",
187        );
188    }
189
190    let prompt_template = crate::prompt_assets::get_prompt("repair.md");
191    Ok(prompt_template
192        .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
193        .replace("{quality_feedback}", &quality_feedback)
194        .replace("{token_change_info}", &token_change_info))
195}
196
197/// Check if an example needs repair based on QA feedback or computed scores.
198pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
199    // Check QA-based signals first.
200    if let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) {
201        if qa.reverts_edits == Some(true) {
202            return true;
203        }
204
205        if let Some(confidence) = qa.confidence {
206            if confidence <= confidence_threshold {
207                return true;
208            }
209        }
210
211        return false;
212    }
213
214    // When QA is unavailable, fall back to computed score signals.
215    if let Some(score) = example.score.first() {
216        if score.reversal_ratio > 0.9 {
217            return true;
218        }
219
220        if score.wrong_editable_region == Some(true) {
221            return true;
222        }
223    }
224
225    false
226}
227
228/// Parse repair model output into a patch and optional cursor.
229///
230/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
231/// and delegates normal output to `TeacherPrompt::parse`.
232pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
233    if actual_output.contains(KEEP_PREVIOUS) {
234        let original = example
235            .predictions
236            .first()
237            .context("no original prediction to keep")?;
238        let patch = original.actual_patch.clone().unwrap_or_default();
239        let cursor = original.actual_cursor.clone();
240        return Ok((patch, cursor));
241    }
242
243    TeacherPrompt::parse(example, actual_output)
244}
245
246/// Check if an example already has a successful repair prediction.
247fn has_successful_repair(example: &Example) -> bool {
248    example
249        .predictions
250        .iter()
251        .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
252}
253
// Lazily-initialized, process-wide LLM clients — one per (provider, mode)
// pair so batch-cached and plain (synchronous) requests never share a client.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
258
/// Run repair for a single example.
///
/// This sends a multi-turn conversation to the LLM:
/// - Turn 1 (User): Original teacher prompt
/// - Turn 2 (Assistant): Original teacher response
/// - Turn 3 (User): Repair critique and instructions
/// - Turn 4 (Assistant): Improved prediction (the response we parse)
///
/// On success appends a `PredictionProvider::Repair` prediction to
/// `example.predictions` (with `error` set when the response fails to parse).
/// Returns early without touching the example when repair is unnecessary or
/// when a batch request has been enqueued but no result is available yet.
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Idempotence: never re-repair an example with an existing good repair.
    if has_successful_repair(example) {
        return Ok(());
    }

    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // Preconditions: repair builds on earlier pipeline stages
    // (context retrieval, format_prompt, predict).
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    let teacher_prompt = example
        .prompt
        .as_ref()
        .context("prompt missing (run format_prompt first)")?;

    // The first prediction is treated as the teacher response (Turn 2).
    let teacher_response = &example.predictions[0].actual_output;
    if teacher_response.is_empty() {
        anyhow::bail!("teacher response is empty (run predict first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let repair_message = build_repair_message(example).context("Failed to build repair message")?;

    step_progress.set_substatus("generating");

    // Build the provider-specific 3-turn message list and generate Turn 4.
    // `generate` returning `None` appears to mean "batch request enqueued, no
    // result yet", in which case we bail out without recording a prediction.
    let response = match args.backend {
        BatchProvider::Anthropic => {
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_prompt.input.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 2: Original teacher response
                anthropic::Message {
                    role: anthropic::Role::Assistant,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_response.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 3: Repair critique and instructions
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: repair_message,
                        cache_control: None,
                    }],
                },
            ];

            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text parts of the response into one string.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(teacher_prompt.input.clone()),
                },
                // Turn 2: Original teacher response
                open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(teacher_response.clone())),
                    tool_calls: vec![],
                },
                // Turn 3: Repair critique and instructions
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(repair_message),
                },
            ];

            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Flatten assistant choices (plain or multipart) into one string.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    // Record the repair prediction even when parsing fails; the error string
    // is kept on the prediction so the failure is inspectable downstream.
    let parse_result = parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {}", e));

    let (actual_patch, actual_cursor) = parse_result.ok().unzip();
    let actual_cursor = actual_cursor.flatten();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor,
        error: err,
        provider: PredictionProvider::Repair,
        cumulative_logprob: None,
        avg_logprob: None,
    });

    Ok(())
}
435
436/// Sync batches for repair (upload pending requests, download finished results).
437pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
438    if args.no_batch {
439        return Ok(());
440    }
441
442    match args.backend {
443        BatchProvider::Anthropic => {
444            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
445                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
446            });
447            client.sync_batches().await?;
448        }
449        BatchProvider::Openai => {
450            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
451                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
452            });
453            client.sync_batches().await?;
454        }
455    }
456
457    Ok(())
458}
459
460pub async fn reprocess_after_batch_wait(examples: &mut [Example], args: &RepairArgs) -> Result<()> {
461    let mut reprocessed = 0;
462    for example in examples.iter_mut() {
463        if has_successful_repair(example) || !needs_repair(example, args.confidence_threshold) {
464            continue;
465        }
466
467        let example_progress = Progress::global().start_group(&example.spec.name);
468        run_repair(example, args, &example_progress).await?;
469        reprocessed += 1;
470    }
471
472    if reprocessed > 0 {
473        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
474    }
475
476    Ok(())
477}
478
/// Block until every pending repair batch request has completed, polling the
/// provider every 30 seconds and syncing results after each poll.
pub async fn wait_for_batches(args: &RepairArgs) -> Result<()> {
    // Nothing to wait for in synchronous (non-batch) mode.
    if args.no_batch {
        return Ok(());
    }

    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(args)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending repair batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        // NOTE(review): std::thread::sleep blocks the executor thread inside
        // an async fn; if other tasks share this runtime, an async sleep
        // (e.g. tokio::time::sleep) would be preferable — confirm which
        // runtime this binary uses before changing it.
        std::thread::sleep(poll_interval);

        // Upload any pending requests / download finished results, then
        // re-check the pending count at the top of the loop.
        sync_batches(args).await?;
    }

    Ok(())
}
504
505fn pending_batch_count(args: &RepairArgs) -> Result<usize> {
506    match args.backend {
507        BatchProvider::Anthropic => {
508            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
509                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
510            });
511            client.pending_batch_count()
512        }
513        BatchProvider::Openai => {
514            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
515                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
516            });
517            client.pending_batch_count()
518        }
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PredictionProvider, TeacherBackend};
    use edit_prediction::example_spec::ExampleSpec;
    use std::{path::Path, sync::Arc};

    /// Build a minimal example carrying a single teacher prediction whose
    /// patch ("previous patch") and cursor (offset 3) are distinctive enough
    /// to assert that `parse` copied them for the KEEP_PREVIOUS case.
    fn example_with_previous_prediction() -> Example {
        Example {
            spec: ExampleSpec {
                name: "example".to_string(),
                repository_url: "https://github.com/zed-industries/zed.git".to_string(),
                revision: "HEAD".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: String::new(),
                cursor_path: Arc::from(Path::new("src/main.rs")),
                cursor_position: "0:0".to_string(),
                edit_history: String::new(),
                expected_patches: Vec::new(),
                rejected_patch: None,
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            },
            prompt_inputs: None,
            prompt: None,
            predictions: vec![ExamplePrediction {
                actual_patch: Some("previous patch".to_string()),
                actual_output: String::new(),
                actual_cursor: Some(ActualCursor {
                    path: "src/main.rs".to_string(),
                    row: 1,
                    column: 2,
                    offset: 3,
                    editable_region_offset: Some(4),
                }),
                error: None,
                provider: PredictionProvider::Teacher(TeacherBackend::Sonnet45),
                cumulative_logprob: None,
                avg_logprob: None,
            }],
            score: Vec::new(),
            qa: Vec::new(),
            zed_version: None,
            state: None,
        }
    }

    #[test]
    fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
        let example = example_with_previous_prediction();
        // The sentinel appears in prose, NOT inside the trailing code block;
        // `parse` uses a substring check, so it must still keep the previous
        // prediction rather than parsing the unrelated code block.
        let actual_output = indoc::indoc! {"
            After reviewing the feedback, the previous prediction is still correct.
            Use `KEEP_PREVIOUS`.

            ```
            unrelated trailing code block
            ```
        "};

        let (patch, cursor) = parse(&example, actual_output).unwrap();

        assert_eq!(patch, "previous patch");
        assert_eq!(cursor.unwrap().offset, 3);
    }
}
588}