1//! Repair predictions that received poor QA scores.
2//!
3//! This module takes examples with predictions and QA feedback, identifies
4//! predictions that need improvement (based on reverts_edits or low confidence),
5//! and uses an LLM to generate improved predictions.
6
7use crate::{
8 BatchProvider, PredictionProvider,
9 anthropic_client::AnthropicClient,
10 example::{Example, ExamplePrediction},
11 format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example},
12 openai_client::OpenAiClient,
13 parse_output::run_parse_output,
14 paths::LLM_CACHE_DB,
15 progress::{ExampleProgress, Step},
16 word_diff::unified_to_word_diff,
17};
18use anyhow::{Context as _, Result};
19use std::sync::OnceLock;
20
/// Arguments for the repair command.
//
// NOTE(review): the `///` doc comments on the fields below double as clap
// `--help` text; edit them with user-facing output in mind.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,
}
36
37fn model_for_backend(backend: BatchProvider) -> &'static str {
38 match backend {
39 BatchProvider::Anthropic => "claude-sonnet-4-5",
40 BatchProvider::Openai => "gpt-5.2",
41 }
42}
43
44/// Build the repair prompt for an example that needs improvement.
45pub fn build_repair_prompt(example: &Example) -> Result<String> {
46 let prediction = example
47 .predictions
48 .first()
49 .context("no predictions available")?;
50 let qa = example
51 .qa
52 .first()
53 .context("no QA results available")?
54 .as_ref()
55 .context("QA result is None")?;
56 let prompt_inputs = example
57 .prompt_inputs
58 .as_ref()
59 .context("prompt_inputs missing (run context retrieval first)")?;
60 let actual_patch = prediction
61 .actual_patch
62 .as_ref()
63 .context("no actual_patch available (run predict first)")?;
64
65 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
66
67 let mut edit_history = String::new();
68 for event in &prompt_inputs.edit_history {
69 match event.as_ref() {
70 zeta_prompt::Event::BufferChange {
71 path,
72 old_path,
73 diff,
74 predicted: _,
75 in_open_source_repo: _,
76 } => {
77 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
78 edit_history.push_str(&format!("+++ b{}\n", path.display()));
79 let diff_word_diff = unified_to_word_diff(diff);
80 edit_history.push_str(&diff_word_diff);
81 edit_history.push_str("\n\n");
82 }
83 }
84 }
85
86 let context = TeacherPrompt::format_context(example);
87
88 let cursor_excerpt =
89 extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?;
90
91 let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
92 let reverts_edits = qa
93 .reverts_edits
94 .map_or("unknown", |v| if v { "yes" } else { "no" });
95 let confidence = qa
96 .confidence
97 .map_or("unknown".to_string(), |v| v.to_string());
98
99 let prompt_template = crate::prompt_assets::get_prompt("repair.md");
100 Ok(prompt_template
101 .replace("{edit_history}", &edit_history)
102 .replace("{context}", &context)
103 .replace("{cursor_excerpt}", &cursor_excerpt)
104 .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
105 .replace("{reverts_edits}", reverts_edits)
106 .replace("{confidence}", &confidence)
107 .replace("{qa_reasoning}", qa_reasoning))
108}
109
110/// Check if an example needs repair based on QA feedback.
111pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
112 let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) else {
113 return false;
114 };
115
116 if qa.reverts_edits == Some(true) {
117 return true;
118 }
119
120 if let Some(confidence) = qa.confidence {
121 if confidence <= confidence_threshold {
122 return true;
123 }
124 }
125
126 false
127}
128
129/// Check if an example already has a successful repair prediction.
130fn has_successful_repair(example: &Example) -> bool {
131 example
132 .predictions
133 .iter()
134 .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
135}
136
// Lazily-initialized, process-wide LLM clients. Separate instances are kept
// for batch and plain (synchronous) modes of each provider so both modes can
// coexist in a single run; `sync_batches` below reuses the batch instances.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
141
/// Run repair for a single example.
///
/// Skips examples that already have a successful repair prediction or whose
/// QA feedback does not flag the prediction. Otherwise builds a repair
/// prompt, sends it to the configured LLM backend, and appends the parsed
/// result to `example.predictions` as a `Repair` prediction.
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Idempotence: an earlier run already produced a usable repair.
    if has_successful_repair(example) {
        return Ok(());
    }

    // Only repair predictions that QA flagged (reverts edits / low confidence).
    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // Fail early with messages naming the pipeline step that must run first.
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    if example.qa.is_empty() {
        anyhow::bail!("no QA results available (run qa first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let prompt = build_repair_prompt(example).context("Failed to build repair prompt")?;

    step_progress.set_substatus("generating");

    // Dispatch to the chosen provider; both arms reduce the raw response to
    // a single concatenated text string.
    let response = match args.backend {
        BatchProvider::Anthropic => {
            // Clients are process-wide singletons; batch and plain modes use
            // separate instances.
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![anthropic::Message {
                role: anthropic::Role::User,
                content: vec![anthropic::RequestContent::Text {
                    text: prompt,
                    cache_control: None,
                }],
            }];

            // `None` means no response is available yet — presumably a queued
            // batch request (TODO confirm against the client's contract).
            // Leave the example untouched so a later run can pick it up.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text content blocks, ignoring non-text parts.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }];

            // As above: `None` means the result is not ready yet.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Flatten assistant choices into one string; multipart content is
            // reduced to its text parts, other message kinds are ignored.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    // A parse failure is recorded on the prediction rather than propagated,
    // so the raw model output is preserved for later inspection.
    let parse_result = TeacherPrompt::parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {}", e));

    let (actual_patch, actual_cursor_offset) = parse_result.ok().unzip();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor_offset: actual_cursor_offset.flatten(),
        error: err,
        provider: PredictionProvider::Repair,
    });

    Ok(())
}
272
273/// Sync batches for repair (upload pending requests, download finished results).
274pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
275 if args.no_batch {
276 return Ok(());
277 }
278
279 match args.backend {
280 BatchProvider::Anthropic => {
281 let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
282 AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
283 });
284 client.sync_batches().await?;
285 }
286 BatchProvider::Openai => {
287 let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
288 OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
289 });
290 client.sync_batches().await?;
291 }
292 }
293
294 Ok(())
295}