1//! Repair predictions that received poor QA scores.
2//!
3//! This module takes examples with predictions and QA feedback, identifies
4//! predictions that need improvement (based on reverts_edits or low confidence),
5//! and uses an LLM to generate improved predictions.
6
7use crate::{
8 BatchProvider, PredictionProvider,
9 anthropic_client::AnthropicClient,
10 example::{Example, ExamplePrediction},
11 format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example},
12 openai_client::OpenAiClient,
13 parse_output::run_parse_output,
14 paths::LLM_CACHE_DB,
15 progress::{ExampleProgress, Step},
16 word_diff::unified_to_word_diff,
17};
18use anyhow::{Context as _, Result};
19use std::sync::OnceLock;
20
/// Arguments for the repair command.
//
// NOTE(review): the `///` doc comments on the fields below double as clap
// `--help` text; edit them with user-facing output in mind.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,
}
36
37fn model_for_backend(backend: BatchProvider) -> &'static str {
38 match backend {
39 BatchProvider::Anthropic => "claude-sonnet-4-5",
40 BatchProvider::Openai => "gpt-5.2",
41 }
42}
43
44/// Build the repair prompt for an example that needs improvement.
45pub fn build_repair_prompt(example: &Example) -> Result<String> {
46 let prediction = example
47 .predictions
48 .first()
49 .context("no predictions available")?;
50 let qa = example
51 .qa
52 .first()
53 .context("no QA results available")?
54 .as_ref()
55 .context("QA result is None")?;
56 let prompt_inputs = example
57 .prompt_inputs
58 .as_ref()
59 .context("prompt_inputs missing (run context retrieval first)")?;
60 let actual_patch = prediction
61 .actual_patch
62 .as_ref()
63 .context("no actual_patch available (run predict first)")?;
64
65 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
66
67 let mut edit_history = String::new();
68 for event in &prompt_inputs.edit_history {
69 match event.as_ref() {
70 zeta_prompt::Event::BufferChange {
71 path,
72 old_path,
73 diff,
74 predicted: _,
75 in_open_source_repo: _,
76 } => {
77 edit_history.push_str(&format!("--- a{}\n", old_path.display()));
78 edit_history.push_str(&format!("+++ b{}\n", path.display()));
79 let diff_word_diff = unified_to_word_diff(diff);
80 edit_history.push_str(&diff_word_diff);
81 edit_history.push_str("\n\n");
82 }
83 }
84 }
85
86 let context = TeacherPrompt::format_context(example);
87
88 let cursor_excerpt =
89 extract_cursor_excerpt_from_example(example).context("failed to extract cursor excerpt")?;
90
91 let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
92 let reverts_edits = qa
93 .reverts_edits
94 .map_or("unknown", |v| if v { "yes" } else { "no" });
95 let confidence = qa
96 .confidence
97 .map_or("unknown".to_string(), |v| v.to_string());
98
99 let prompt_template = crate::prompt_assets::get_prompt("repair.md");
100 Ok(prompt_template
101 .replace("{edit_history}", &edit_history)
102 .replace("{context}", &context)
103 .replace("{cursor_excerpt}", &cursor_excerpt)
104 .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
105 .replace("{reverts_edits}", reverts_edits)
106 .replace("{confidence}", &confidence)
107 .replace("{qa_reasoning}", qa_reasoning))
108}
109
110/// Check if an example needs repair based on QA feedback.
111pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
112 let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) else {
113 return false;
114 };
115
116 if qa.reverts_edits == Some(true) {
117 return true;
118 }
119
120 if let Some(confidence) = qa.confidence {
121 if confidence <= confidence_threshold {
122 return true;
123 }
124 }
125
126 false
127}
128
129/// Check if an example already has a successful repair prediction.
130fn has_successful_repair(example: &Example) -> bool {
131 example
132 .predictions
133 .iter()
134 .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
135}
136
// Lazily-initialized, process-wide LLM clients. Separate instances are kept
// for batch and plain (synchronous) modes of each provider so both modes can
// coexist in a single run; `sync_batches` below reuses the batch instances.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
141
/// Run repair for a single example.
///
/// Skips examples that already have a successful repair prediction or whose
/// QA feedback does not flag the prediction. Otherwise builds a repair
/// prompt, sends it to the configured LLM backend, and appends the parsed
/// result to `example.predictions` as a `Repair` prediction.
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Idempotence: an earlier run already produced a usable repair.
    if has_successful_repair(example) {
        return Ok(());
    }

    // Only repair predictions that QA flagged (reverts edits / low confidence).
    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // Fail early with messages naming the pipeline step that must run first.
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    if example.qa.is_empty() {
        anyhow::bail!("no QA results available (run qa first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let prompt = build_repair_prompt(example).context("Failed to build repair prompt")?;

    step_progress.set_substatus("generating");

    // Dispatch to the chosen provider; both arms reduce the raw response to
    // a single concatenated text string.
    let response = match args.backend {
        BatchProvider::Anthropic => {
            // Clients are process-wide singletons; batch and plain modes use
            // separate instances.
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![anthropic::Message {
                role: anthropic::Role::User,
                content: vec![anthropic::RequestContent::Text {
                    text: prompt,
                    cache_control: None,
                }],
            }];

            // `None` means no response is available yet — presumably a queued
            // batch request (TODO confirm against the client's contract).
            // Leave the example untouched so a later run can pick it up.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text content blocks, ignoring non-text parts.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }];

            // As above: `None` means the result is not ready yet.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Flatten assistant choices into one string; multipart content is
            // reduced to its text parts, other message kinds are ignored.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    // A parse failure is recorded on the prediction rather than propagated,
    // so the raw model output is preserved for later inspection.
    let parse_result = TeacherPrompt::parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {}", e));

    let (actual_patch, actual_cursor_offset) = parse_result.ok().unzip();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor_offset: actual_cursor_offset.flatten(),
        error: err,
        provider: PredictionProvider::Repair,
    });

    Ok(())
}
272
273/// Sync batches for repair (upload pending requests, download finished results).
274pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
275 if args.no_batch {
276 return Ok(());
277 }
278
279 match args.backend {
280 BatchProvider::Anthropic => {
281 let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
282 AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
283 });
284 client.sync_batches().await?;
285 }
286 BatchProvider::Openai => {
287 let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
288 OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
289 });
290 client.sync_batches().await?;
291 }
292 }
293
294 Ok(())
295}