1//! Repair predictions that received poor quality signals.
2//!
3//! This module takes examples with predictions, identifies predictions that need
4//! improvement, and uses an LLM to generate improved predictions. It supports
5//! two sources of quality signals:
6//! - QA feedback (reverts_edits or low confidence)
7//! - Computed scores when QA is unavailable (high reversal_ratio or wrong_editable_region)
8
9use crate::{
10 BatchProvider, PredictionProvider,
11 anthropic_client::AnthropicClient,
12 example::{ActualCursor, Example, ExamplePrediction},
13 format_prompt::TeacherPrompt,
14 metrics::count_patch_token_changes,
15 openai_client::OpenAiClient,
16 parse_output::run_parse_output,
17 paths::LLM_CACHE_DB,
18 progress::{ExampleProgress, Step},
19 word_diff::unified_to_word_diff,
20};
21use anyhow::{Context as _, Result};
22use std::sync::OnceLock;
23
/// Sentinel token the repair model emits to signal that the previous
/// (teacher) prediction should be kept unchanged.
const KEEP_PREVIOUS: &str = "KEEP_PREVIOUS";
25
26/// Print a summary report of repair results across all examples.
27pub fn print_report(examples: &[Example], confidence_threshold: u8) {
28 let total = examples.len();
29 let mut no_repair_needed = 0;
30 let mut repaired = 0;
31 let mut repair_failed = 0;
32
33 for example in examples {
34 if !needs_repair(example, confidence_threshold) {
35 no_repair_needed += 1;
36 continue;
37 }
38
39 if has_successful_repair(example) {
40 repaired += 1;
41 } else {
42 repair_failed += 1;
43 }
44 }
45
46 let needed_repair = total - no_repair_needed;
47
48 eprintln!();
49 eprintln!("Repair summary ({total} examples):");
50 eprintln!(
51 " {no_repair_needed}/{total} didn't need repair (confidence > {confidence_threshold})"
52 );
53 if needed_repair > 0 {
54 eprintln!(" {needed_repair}/{total} needed repair:");
55 if repaired > 0 {
56 eprintln!(" {repaired} repaired successfully");
57 }
58 if repair_failed > 0 {
59 eprintln!(" {repair_failed} failed to repair");
60 }
61 }
62}
63
/// Arguments for the repair command.
//
// NOTE: the `///` doc comments on the fields below double as clap's `--help`
// text, so they are user-facing strings — edit with care.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,
}
79
/// Map an LLM backend to the model identifier used for repair requests.
fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-6",
        BatchProvider::Openai => "gpt-5.2",
    }
}
86
87/// Build the quality feedback string from QA results.
88fn build_qa_feedback(example: &Example) -> Option<String> {
89 let qa = example.qa.first()?.as_ref()?;
90
91 let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
92 let reverts_edits = qa
93 .reverts_edits
94 .map_or("unknown", |v| if v { "yes" } else { "no" });
95 let confidence = qa
96 .confidence
97 .map_or("unknown".to_string(), |v| v.to_string());
98
99 Some(format!(
100 "- **Reverts user edits**: {reverts_edits}\n\
101 - **Confidence score**: {confidence}/5\n\
102 - **Reasoning**: {qa_reasoning}"
103 ))
104}
105
106/// Build the quality feedback string from computed scores when QA is unavailable.
107fn build_score_feedback(example: &Example) -> Option<String> {
108 let score = example.score.first()?;
109
110 let mut issues = Vec::new();
111
112 if score.reversal_ratio > 0.9 {
113 issues.push(format!(
114 "Automated analysis detected a high reversal ratio ({:.2}), which suggests this \
115 prediction may be reverting changes the user intentionally made. Double-check that \
116 the prediction doesn't undo the user's recent edits. If the prediction is actually \
117 fine and the edits are intentional completions rather than reversals, keep it as-is. \
118 If it truly reverts the user's changes, generate an improved prediction that \
119 continues the user's intent instead.",
120 score.reversal_ratio
121 ));
122 }
123
124 if score.wrong_editable_region == Some(true) {
125 issues.push(
126 "Automated analysis detected that the prediction may be modifying code outside \
127 the expected editable region, or producing changes misaligned with the editable \
128 region boundaries. Make sure the prediction only modifies code within the editable \
129 region and is properly aligned."
130 .to_string(),
131 );
132 }
133
134 if issues.is_empty() {
135 return None;
136 }
137
138 let mut feedback = String::from(
139 "No human quality assessment is available, but automated scoring flagged potential issues:\n\n",
140 );
141 for issue in &issues {
142 feedback.push_str(&format!("- {issue}\n"));
143 }
144 feedback.push_str(
145 "\nRemember: if the previous prediction was actually correct, output `KEEP_PREVIOUS`. \
146 If no edits should be made at all and you are unsure how to improve it, output `NO_EDITS`.",
147 );
148
149 Some(feedback)
150}
151
152/// Build the repair message (Turn 3) for a multi-turn conversation.
153///
154/// This message is sent after the original teacher prompt (Turn 1) and
155/// teacher response (Turn 2) to request an improved prediction.
156pub fn build_repair_message(example: &Example) -> Result<String> {
157 let prediction = example
158 .predictions
159 .first()
160 .context("no predictions available")?;
161 let actual_patch = prediction
162 .actual_patch
163 .as_ref()
164 .context("no actual_patch available (run predict first)")?;
165
166 let quality_feedback = build_qa_feedback(example)
167 .or_else(|| build_score_feedback(example))
168 .context("no quality feedback available (need either QA results or computed scores)")?;
169
170 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
171
172 let token_counts = count_patch_token_changes(actual_patch);
173 let mut token_change_info = format!(
174 "\n## Token Change Statistics\n\n\
175 - **Deleted tokens**: {}\n\
176 - **Inserted tokens**: {}",
177 token_counts.deleted_tokens, token_counts.inserted_tokens,
178 );
179 if token_counts.deleted_tokens > 100 || token_counts.inserted_tokens > 100 {
180 token_change_info.push_str(
181 "\n\n> **Note:** The token change count is high. \
182 Consider producing a more scoped edit that targets only the lines \
183 that truly need to change, rather than rewriting large sections.",
184 );
185 }
186
187 let prompt_template = crate::prompt_assets::get_prompt("repair.md");
188 Ok(prompt_template
189 .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
190 .replace("{quality_feedback}", &quality_feedback)
191 .replace("{token_change_info}", &token_change_info))
192}
193
194/// Check if an example needs repair based on QA feedback or computed scores.
195pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
196 // Check QA-based signals first.
197 if let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) {
198 if qa.reverts_edits == Some(true) {
199 return true;
200 }
201
202 if let Some(confidence) = qa.confidence {
203 if confidence <= confidence_threshold {
204 return true;
205 }
206 }
207
208 return false;
209 }
210
211 // When QA is unavailable, fall back to computed score signals.
212 if let Some(score) = example.score.first() {
213 if score.reversal_ratio > 0.9 {
214 return true;
215 }
216
217 if score.wrong_editable_region == Some(true) {
218 return true;
219 }
220 }
221
222 false
223}
224
225/// Parse repair model output into a patch and optional cursor.
226///
227/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
228/// and delegates normal output to `TeacherPrompt::parse`.
229pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
230 if actual_output.contains(KEEP_PREVIOUS) {
231 let original = example
232 .predictions
233 .first()
234 .context("no original prediction to keep")?;
235 let patch = original.actual_patch.clone().unwrap_or_default();
236 let cursor = original.actual_cursor.clone();
237 return Ok((patch, cursor));
238 }
239
240 TeacherPrompt::parse(example, actual_output)
241}
242
243/// Check if an example already has a successful repair prediction.
244fn has_successful_repair(example: &Example) -> bool {
245 example
246 .predictions
247 .iter()
248 .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
249}
250
// Lazily-initialized, process-wide LLM clients. Separate cells are kept per
// provider and per mode (batch vs. plain/synchronous) so that each
// combination is constructed at most once and reused across examples.
static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
255
/// Run repair for a single example.
///
/// This sends a multi-turn conversation to the LLM:
/// - Turn 1 (User): Original teacher prompt
/// - Turn 2 (Assistant): Original teacher response
/// - Turn 3 (User): Repair critique and instructions
/// - Turn 4 (Assistant): Improved prediction (the response we parse)
///
/// On completion a `PredictionProvider::Repair` prediction is appended to
/// `example.predictions`; parse failures are recorded on that prediction's
/// `error` field rather than failing the step.
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    // Idempotence: skip examples that already have a usable repair.
    if has_successful_repair(example) {
        return Ok(());
    }

    // Only repair examples flagged by QA feedback or computed scores.
    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    // Preconditions: upstream pipeline steps must have populated these.
    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    let teacher_prompt = example
        .prompt
        .as_ref()
        .context("prompt missing (run format_prompt first)")?;

    // The first prediction is treated as the teacher's response (Turn 2).
    let teacher_response = &example.predictions[0].actual_output;
    if teacher_response.is_empty() {
        anyhow::bail!("teacher response is empty (run predict first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let repair_message = build_repair_message(example).context("Failed to build repair message")?;

    step_progress.set_substatus("generating");

    // Assemble the three-turn conversation in the provider-specific message
    // format, send it, and flatten the reply into a single response string.
    let response = match args.backend {
        BatchProvider::Anthropic => {
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_prompt.input.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 2: Original teacher response
                anthropic::Message {
                    role: anthropic::Role::Assistant,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_response.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 3: Repair critique and instructions
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: repair_message,
                        cache_control: None,
                    }],
                },
            ];

            // `None` means no response is available yet — presumably a batch
            // request was enqueued and will be collected by a later sync.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Concatenate all text content blocks; non-text content is ignored.
            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(teacher_prompt.input.clone()),
                },
                // Turn 2: Original teacher response
                open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(teacher_response.clone())),
                    tool_calls: vec![],
                },
                // Turn 3: Repair critique and instructions
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(repair_message),
                },
            ];

            // `None` means no response is available yet — presumably a batch
            // request was enqueued and will be collected by a later sync.
            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            // Flatten every assistant choice's content (plain or multipart
            // text) into one string; non-assistant or empty choices drop out.
            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    // Parse the model output; a parse failure is stored on the prediction
    // record (so the raw output is preserved) instead of aborting.
    let parse_result = parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {}", e));

    let (actual_patch, actual_cursor) = parse_result.ok().unzip();
    let actual_cursor = actual_cursor.flatten();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor,
        error: err,
        provider: PredictionProvider::Repair,
        cumulative_logprob: None,
        avg_logprob: None,
    });

    Ok(())
}
432
433/// Sync batches for repair (upload pending requests, download finished results).
434pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
435 if args.no_batch {
436 return Ok(());
437 }
438
439 match args.backend {
440 BatchProvider::Anthropic => {
441 let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
442 AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
443 });
444 client.sync_batches().await?;
445 }
446 BatchProvider::Openai => {
447 let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
448 OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
449 });
450 client.sync_batches().await?;
451 }
452 }
453
454 Ok(())
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PredictionProvider, TeacherBackend};
    use edit_prediction::example_spec::ExampleSpec;
    use std::{path::Path, sync::Arc};

    /// Build a minimal example carrying one teacher prediction
    /// ("previous patch", cursor offset 3) for `parse` to fall back on
    /// when the `KEEP_PREVIOUS` sentinel appears in the model output.
    fn example_with_previous_prediction() -> Example {
        Example {
            spec: ExampleSpec {
                name: "example".to_string(),
                repository_url: "https://github.com/zed-industries/zed.git".to_string(),
                revision: "HEAD".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: String::new(),
                cursor_path: Arc::from(Path::new("src/main.rs")),
                cursor_position: "0:0".to_string(),
                edit_history: String::new(),
                expected_patches: Vec::new(),
                rejected_patch: None,
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            },
            prompt_inputs: None,
            prompt: None,
            predictions: vec![ExamplePrediction {
                actual_patch: Some("previous patch".to_string()),
                actual_output: String::new(),
                actual_cursor: Some(ActualCursor {
                    path: "src/main.rs".to_string(),
                    row: 1,
                    column: 2,
                    offset: 3,
                    editable_region_offset: Some(4),
                }),
                error: None,
                provider: PredictionProvider::Teacher(TeacherBackend::Sonnet45),
                cumulative_logprob: None,
                avg_logprob: None,
            }],
            score: Vec::new(),
            qa: Vec::new(),
            zed_version: None,
            state: None,
        }
    }

    #[test]
    fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
        let example = example_with_previous_prediction();
        // The sentinel is in the prose, not the trailing code block: `parse`
        // matches it anywhere in the output, so the previous prediction wins.
        let actual_output = indoc::indoc! {"
            After reviewing the feedback, the previous prediction is still correct.
            Use `KEEP_PREVIOUS`.

            ```
            unrelated trailing code block
            ```
        "};

        let (patch, cursor) = parse(&example, actual_output).unwrap();

        assert_eq!(patch, "previous patch");
        assert_eq!(cursor.unwrap().offset, 3);
    }
}