1//! Repair predictions that received poor quality signals.
2//!
3//! This module takes examples with predictions, identifies predictions that need
4//! improvement, and uses an LLM to generate improved predictions. It supports
5//! two sources of quality signals:
6//! - QA feedback (reverts_edits or low confidence)
7//! - Computed scores when QA is unavailable (high reversal_ratio or wrong_editable_region)
8
9use crate::{
10 BatchProvider, PredictionProvider,
11 anthropic_client::AnthropicClient,
12 example::{ActualCursor, Example, ExamplePrediction},
13 format_prompt::TeacherPrompt,
14 metrics::count_patch_token_changes,
15 openai_client::OpenAiClient,
16 parse_output::run_parse_output,
17 paths::LLM_CACHE_DB,
18 progress::{ExampleProgress, Step},
19 word_diff::unified_to_word_diff,
20};
21use anyhow::{Context as _, Result};
22use std::sync::OnceLock;
23
/// Sentinel token the repair model emits to signal that the previous
/// (teacher) prediction should be kept unchanged.
const KEEP_PREVIOUS: &str = "KEEP_PREVIOUS";
25
26/// Print a summary report of repair results across all examples.
27pub fn print_report(examples: &[Example], confidence_threshold: u8) {
28 let total = examples.len();
29 let mut no_repair_needed = 0;
30 let mut repaired = 0;
31 let mut repair_failed = 0;
32
33 for example in examples {
34 if !needs_repair(example, confidence_threshold) {
35 no_repair_needed += 1;
36 continue;
37 }
38
39 if has_successful_repair(example) {
40 repaired += 1;
41 } else {
42 repair_failed += 1;
43 }
44 }
45
46 let needed_repair = total - no_repair_needed;
47
48 eprintln!();
49 eprintln!("Repair summary ({total} examples):");
50 eprintln!(
51 " {no_repair_needed}/{total} didn't need repair (confidence > {confidence_threshold})"
52 );
53 if needed_repair > 0 {
54 eprintln!(" {needed_repair}/{total} needed repair:");
55 if repaired > 0 {
56 eprintln!(" {repaired} repaired successfully");
57 }
58 if repair_failed > 0 {
59 eprintln!(" {repair_failed} failed to repair");
60 }
61 }
62}
63
/// Arguments for the repair command.
//
// NOTE: the `///` doc comments on the fields below double as clap's `--help`
// text, so they are user-facing strings — edit with care.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,
}
79
/// Map an LLM backend to the model identifier used for repair requests.
fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-6",
        BatchProvider::Openai => "gpt-5.2",
    }
}
86
87/// Build the quality feedback string from QA results.
88fn build_qa_feedback(example: &Example) -> Option<String> {
89 let qa = example.qa.first()?.as_ref()?;
90
91 let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
92 let reverts_edits = qa
93 .reverts_edits
94 .map_or("unknown", |v| if v { "yes" } else { "no" });
95 let confidence = qa
96 .confidence
97 .map_or("unknown".to_string(), |v| v.to_string());
98
99 Some(format!(
100 "- **Reverts user edits**: {reverts_edits}\n\
101 - **Confidence score**: {confidence}/5\n\
102 - **Reasoning**: {qa_reasoning}"
103 ))
104}
105
106/// Build the quality feedback string from computed scores when QA is unavailable.
107fn build_score_feedback(example: &Example) -> Option<String> {
108 let score = example.score.first()?;
109
110 let mut issues = Vec::new();
111
112 if score.reversal_ratio > 0.9 {
113 issues.push(format!(
114 "Automated analysis detected a high reversal ratio ({:.2}), which suggests this \
115 prediction may be reverting changes the user intentionally made. Double-check that \
116 the prediction doesn't undo the user's recent edits. If the prediction is actually \
117 fine and the edits are intentional completions rather than reversals, keep it as-is. \
118 If it truly reverts the user's changes, generate an improved prediction that \
119 continues the user's intent instead.",
120 score.reversal_ratio
121 ));
122 }
123
124 if score.wrong_editable_region == Some(true) {
125 issues.push(
126 "Automated analysis detected that the prediction may be modifying code outside \
127 the expected editable region, or producing changes misaligned with the editable \
128 region boundaries. Make sure the prediction only modifies code within the editable \
129 region and is properly aligned."
130 .to_string(),
131 );
132 }
133
134 if issues.is_empty() {
135 return None;
136 }
137
138 let mut feedback = String::from(
139 "No human quality assessment is available, but automated scoring flagged potential issues:\n\n",
140 );
141 for issue in &issues {
142 feedback.push_str(&format!("- {issue}\n"));
143 }
144 feedback.push_str(
145 "\nRemember: if the previous prediction was actually correct, output `KEEP_PREVIOUS`. \
146 If no edits should be made at all and you are unsure how to improve it, output `NO_EDITS`.",
147 );
148
149 Some(feedback)
150}
151
152/// Build the repair message (Turn 3) for a multi-turn conversation.
153///
154/// This message is sent after the original teacher prompt (Turn 1) and
155/// teacher response (Turn 2) to request an improved prediction.
156pub fn build_repair_message(example: &Example) -> Result<String> {
157 let prediction = example
158 .predictions
159 .first()
160 .context("no predictions available")?;
161 let actual_patch = prediction
162 .actual_patch
163 .as_ref()
164 .context("no actual_patch available (run predict first)")?;
165
166 let quality_feedback = build_qa_feedback(example)
167 .or_else(|| build_score_feedback(example))
168 .context("no quality feedback available (need either QA results or computed scores)")?;
169
170 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
171
172 let token_counts = count_patch_token_changes(actual_patch);
173 let mut token_change_info = format!(
174 "\n## Token Change Statistics\n\n\
175 - **Deleted tokens**: {}\n\
176 - **Inserted tokens**: {}",
177 token_counts.deleted_tokens, token_counts.inserted_tokens,
178 );
179 if token_counts.deleted_tokens > 100 || token_counts.inserted_tokens > 100 {
180 token_change_info.push_str(
181 "\n\n> **Note:** The token change count is high. \
182 Consider producing a more scoped edit that targets only the lines \
183 that truly need to change, rather than rewriting large sections.",
184 );
185 }
186
187 let prompt_template = crate::prompt_assets::get_prompt("repair.md");
188 Ok(prompt_template
189 .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
190 .replace("{quality_feedback}", &quality_feedback)
191 .replace("{token_change_info}", &token_change_info))
192}
193
194/// Check if an example needs repair based on QA feedback or computed scores.
195pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
196 // Check QA-based signals first.
197 if let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) {
198 if qa.reverts_edits == Some(true) {
199 return true;
200 }
201
202 if let Some(confidence) = qa.confidence {
203 if confidence <= confidence_threshold {
204 return true;
205 }
206 }
207
208 return false;
209 }
210
211 // When QA is unavailable, fall back to computed score signals.
212 if let Some(score) = example.score.first() {
213 if score.reversal_ratio > 0.9 {
214 return true;
215 }
216
217 if score.wrong_editable_region == Some(true) {
218 return true;
219 }
220 }
221
222 false
223}
224
225/// Parse repair model output into a patch and optional cursor.
226///
227/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
228/// and delegates normal output to `TeacherPrompt::parse`.
229pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
230 if actual_output.contains(KEEP_PREVIOUS) {
231 let original = example
232 .predictions
233 .first()
234 .context("no original prediction to keep")?;
235 let patch = original.actual_patch.clone().unwrap_or_default();
236 let cursor = original.actual_cursor.clone();
237 return Ok((patch, cursor));
238 }
239
240 TeacherPrompt::parse(example, actual_output)
241}
242
243/// Check if an example already has a successful repair prediction.
244fn has_successful_repair(example: &Example) -> bool {
245 example
246 .predictions
247 .iter()
248 .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
249}
250
// Lazily-initialized, process-wide LLM clients. Separate cells are kept per
// provider and per mode (batch vs. plain/synchronous) so that each
// combination is constructed at most once and reused across examples.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
255
/// Run repair for a single example.
///
/// This sends a multi-turn conversation to the LLM:
/// - Turn 1 (User): Original teacher prompt
/// - Turn 2 (Assistant): Original teacher response
/// - Turn 3 (User): Repair critique and instructions
/// - Turn 4 (Assistant): Improved prediction (the response we parse)
///
/// On completion a `PredictionProvider::Repair` prediction is appended to
/// `example.predictions`; parse failures are recorded on that prediction's
/// `error` field rather than failing the step.
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Idempotence: skip examples that already have a usable repair.
    if has_successful_repair(example) {
        return Ok(());
    }

    // Only repair examples flagged by QA feedback or computed scores.
    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // Preconditions: upstream pipeline steps must have populated these.
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    let teacher_prompt = example
        .prompt
        .as_ref()
        .context("prompt missing (run format_prompt first)")?;

    // The first prediction is treated as the teacher's response (Turn 2).
    let teacher_response = &example.predictions[0].actual_output;
    if teacher_response.is_empty() {
        anyhow::bail!("teacher response is empty (run predict first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let repair_message = build_repair_message(example).context("Failed to build repair message")?;

    step_progress.set_substatus("generating");

    // Assemble the three-turn conversation in the provider-specific message
    // format, send it, and flatten the reply into a single response string.
    let response = match args.backend {
        BatchProvider::Anthropic => {
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_prompt.input.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 2: Original teacher response
                anthropic::Message {
                    role: anthropic::Role::Assistant,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_response.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 3: Repair critique and instructions
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: repair_message,
                        cache_control: None,
                    }],
                },
            ];

            // `None` means no response is available yet — presumably a batch
            // request was enqueued and will be collected by a later sync.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text content blocks; non-text content is ignored.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(teacher_prompt.input.clone()),
                },
                // Turn 2: Original teacher response
                open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(teacher_response.clone())),
                    tool_calls: vec![],
                },
                // Turn 3: Repair critique and instructions
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(repair_message),
                },
            ];

            // `None` means no response is available yet — presumably a batch
            // request was enqueued and will be collected by a later sync.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Flatten every assistant choice's content (plain or multipart
            // text) into one string; non-assistant or empty choices drop out.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    // Parse the model output; a parse failure is stored on the prediction
    // record (so the raw output is preserved) instead of aborting.
    let parse_result = parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {}", e));

    let (actual_patch, actual_cursor) = parse_result.ok().unzip();
    let actual_cursor = actual_cursor.flatten();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor,
        error: err,
        provider: PredictionProvider::Repair,
        cumulative_logprob: None,
        avg_logprob: None,
    });

    Ok(())
}
432
433/// Sync batches for repair (upload pending requests, download finished results).
434pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
435 if args.no_batch {
436 return Ok(());
437 }
438
439 match args.backend {
440 BatchProvider::Anthropic => {
441 let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
442 AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
443 });
444 client.sync_batches().await?;
445 }
446 BatchProvider::Openai => {
447 let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
448 OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
449 });
450 client.sync_batches().await?;
451 }
452 }
453
454 Ok(())
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PredictionProvider, TeacherBackend};
    use edit_prediction::example_spec::ExampleSpec;
    use std::{path::Path, sync::Arc};

    /// Build a minimal example carrying one teacher prediction
    /// ("previous patch", cursor offset 3) for `parse` to fall back on
    /// when the `KEEP_PREVIOUS` sentinel appears in the model output.
    fn example_with_previous_prediction() -> Example {
        Example {
            spec: ExampleSpec {
                name: "example".to_string(),
                repository_url: "https://github.com/zed-industries/zed.git".to_string(),
                revision: "HEAD".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: String::new(),
                cursor_path: Arc::from(Path::new("src/main.rs")),
                cursor_position: "0:0".to_string(),
                edit_history: String::new(),
                expected_patches: Vec::new(),
                rejected_patch: None,
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            },
            prompt_inputs: None,
            prompt: None,
            predictions: vec![ExamplePrediction {
                actual_patch: Some("previous patch".to_string()),
                actual_output: String::new(),
                actual_cursor: Some(ActualCursor {
                    path: "src/main.rs".to_string(),
                    row: 1,
                    column: 2,
                    offset: 3,
                    editable_region_offset: Some(4),
                }),
                error: None,
                provider: PredictionProvider::Teacher(TeacherBackend::Sonnet45),
                cumulative_logprob: None,
                avg_logprob: None,
            }],
            score: Vec::new(),
            qa: Vec::new(),
            zed_version: None,
            state: None,
        }
    }

    #[test]
    fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
        let example = example_with_previous_prediction();
        // The sentinel is in the prose, not the trailing code block: `parse`
        // matches it anywhere in the output, so the previous prediction wins.
        let actual_output = indoc::indoc! {"
            After reviewing the feedback, the previous prediction is still correct.
            Use `KEEP_PREVIOUS`.

            ```
            unrelated trailing code block
            ```
        "};

        let (patch, cursor) = parse(&example, actual_output).unwrap();

        assert_eq!(patch, "previous patch");
        assert_eq!(cursor.unwrap().offset, 3);
    }
}