//! Repair predictions that received poor quality signals.
//!
//! This module takes examples with predictions, identifies predictions that need
//! improvement, and uses an LLM to generate improved predictions. It supports
//! two sources of quality signals:
//! - QA feedback (`reverts_edits` or low confidence)
//! - Computed scores when QA is unavailable (high `reversal_ratio` or `wrong_editable_region`)
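//!
//! A minimal end-to-end sketch (hypothetical driver code; the real call sites
//! live in the CLI, and `example` / `args` are assumed to already be populated):
//!
//! ```ignore
//! let example_progress = Progress::global().start_group(&example.spec.name);
//! run_repair(&mut example, &args, &example_progress).await?;
//! sync_batches(&args).await?;
//! ```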

use crate::{
    BatchProvider, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{ActualCursor, Example, ExamplePrediction},
    format_prompt::TeacherPrompt,
    metrics::count_patch_token_changes,
    openai_client::OpenAiClient,
    parse_output::run_parse_output,
    paths::LLM_CACHE_DB,
    progress::{ExampleProgress, Progress, Step},
    word_diff::unified_to_word_diff,
};
use anyhow::{Context as _, Result};
use std::sync::OnceLock;

const KEEP_PREVIOUS: &str = "KEEP_PREVIOUS";

/// Print a summary report of repair results across all examples.
pub fn print_report(examples: &[Example], confidence_threshold: u8) {
    let total = examples.len();
    let mut no_repair_needed = 0;
    let mut repaired = 0;
    let mut repair_failed = 0;

    for example in examples {
        if !needs_repair(example, confidence_threshold) {
            no_repair_needed += 1;
            continue;
        }

        if has_successful_repair(example) {
            repaired += 1;
        } else {
            repair_failed += 1;
        }
    }

    let needed_repair = total - no_repair_needed;

    eprintln!();
    eprintln!("Repair summary ({total} examples):");
    eprintln!(
        "  {no_repair_needed}/{total} didn't need repair (confidence > {confidence_threshold}, no revert/score flags)"
    );
    if needed_repair > 0 {
        eprintln!("  {needed_repair}/{total} needed repair:");
        if repaired > 0 {
            eprintln!("    {repaired} repaired successfully");
        }
        if repair_failed > 0 {
            eprintln!("    {repair_failed} failed to repair");
        }
    }
}

/// Arguments for the repair command.
#[derive(Debug, Clone, clap::Args)]
pub struct RepairArgs {
    /// Use synchronous API instead of batch
    #[clap(long)]
    pub no_batch: bool,

    /// Confidence threshold: repair predictions with confidence <= this value (1-5)
    #[clap(long, default_value = "2")]
    pub confidence_threshold: u8,

    /// Which LLM provider to use (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    pub backend: BatchProvider,

    /// Wait for all batches to complete before exiting
    #[clap(long)]
    pub wait: bool,
}

fn model_for_backend(backend: BatchProvider) -> &'static str {
    match backend {
        BatchProvider::Anthropic => "claude-sonnet-4-6",
        BatchProvider::Openai => "gpt-5.2",
    }
}

/// Build the quality feedback string from QA results.
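///
/// Produces a small markdown block like the following (illustrative values):
///
/// ```text
/// - **Reverts user edits**: yes
/// - **Confidence score**: 2/5
/// - **Reasoning**: The prediction undoes the user's most recent rename.
/// ```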
fn build_qa_feedback(example: &Example) -> Option<String> {
    let qa = example.qa.first()?.as_ref()?;

    let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
    let reverts_edits = qa
        .reverts_edits
        .map_or("unknown", |v| if v { "yes" } else { "no" });
    let confidence = qa
        .confidence
        .map_or("unknown".to_string(), |v| v.to_string());

    Some(format!(
        "- **Reverts user edits**: {reverts_edits}\n\
         - **Confidence score**: {confidence}/5\n\
         - **Reasoning**: {qa_reasoning}"
    ))
}

/// Build the quality feedback string from computed scores when QA is unavailable.
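///
/// Returns `None` when no score-based signal fires; otherwise the feedback is
/// a bulleted list of detected issues followed by a reminder about the
/// `KEEP_PREVIOUS` and `NO_EDITS` sentinels.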
fn build_score_feedback(example: &Example) -> Option<String> {
    let score = example.score.first()?;

    let mut issues = Vec::new();

    if score.reversal_ratio > 0.9 {
        issues.push(format!(
            "Automated analysis detected a high reversal ratio ({:.2}), which suggests this \
             prediction may be reverting changes the user intentionally made. Double-check that \
             the prediction doesn't undo the user's recent edits. If the prediction is actually \
             fine and the edits are intentional completions rather than reversals, keep it as-is. \
             If it truly reverts the user's changes, generate an improved prediction that \
             continues the user's intent instead.",
            score.reversal_ratio
        ));
    }

    if score.wrong_editable_region == Some(true) {
        issues.push(
            "Automated analysis detected that the prediction may be modifying code outside \
             the expected editable region, or producing changes misaligned with the editable \
             region boundaries. Make sure the prediction only modifies code within the editable \
             region and is properly aligned."
                .to_string(),
        );
    }
    if score.discarded_chars.unwrap_or(0) > 80 && score.exact_lines_fp > 5 {
        issues.push(
            "Automated analysis detected that this prediction might be too large or speculative. \
             Please review it and decide whether to keep it or generate a more focused prediction. \
             Examples of more focused predictions:\n\
             - Predicting a function outline but not its body.\n\
             - Predicting only the first logical step, without speculating about further steps.\n\
             In general, the smaller the prediction you make, the higher the chance it will be correct."
                .to_string(),
        );
    }

    if issues.is_empty() {
        return None;
    }

    let mut feedback = String::from(
        "No human quality assessment is available, but automated scoring flagged potential issues:\n\n",
    );
    for issue in &issues {
        feedback.push_str(&format!("- {issue}\n"));
    }
    feedback.push_str(
        "\nRemember: if the previous prediction was actually correct, output `KEEP_PREVIOUS`. \
         If no edits should be made at all and you are unsure how to improve it, output `NO_EDITS`.",
    );

    Some(feedback)
}

/// Build the repair message (Turn 3) for a multi-turn conversation.
///
/// This message is sent after the original teacher prompt (Turn 1) and the
/// teacher response (Turn 2) to request an improved prediction.
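///
/// The result is the `repair.md` template with `{actual_patch_word_diff}`,
/// `{quality_feedback}`, and `{token_change_info}` substituted. A sketch of
/// where it sits in the conversation (see `run_repair` for the real wiring):
///
/// ```ignore
/// let repair_message = build_repair_message(&example)?;
/// // Turn 1 (User):      example.prompt.input  (original teacher prompt)
/// // Turn 2 (Assistant): predictions[0].actual_output
/// // Turn 3 (User):      repair_message
/// ```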
pub fn build_repair_message(example: &Example) -> Result<String> {
    let prediction = example
        .predictions
        .first()
        .context("no predictions available")?;
    let actual_patch = prediction
        .actual_patch
        .as_ref()
        .context("no actual_patch available (run predict first)")?;

    let quality_feedback = build_qa_feedback(example)
        .or_else(|| build_score_feedback(example))
        .context("no quality feedback available (need either QA results or computed scores)")?;

    let actual_patch_word_diff = unified_to_word_diff(actual_patch);

    let token_counts = count_patch_token_changes(actual_patch);
    let mut token_change_info = format!(
        "\n## Token Change Statistics\n\n\
         - **Deleted tokens**: {}\n\
         - **Inserted tokens**: {}",
        token_counts.deleted_tokens, token_counts.inserted_tokens,
    );
    if token_counts.deleted_tokens > 100 || token_counts.inserted_tokens > 100 {
        token_change_info.push_str(
            "\n\n> **Note:** The token change count is high. \
             Consider producing a more scoped edit that targets only the lines \
             that truly need to change, rather than rewriting large sections.",
        );
    }

    let prompt_template = crate::prompt_assets::get_prompt("repair.md");
    Ok(prompt_template
        .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
        .replace("{quality_feedback}", &quality_feedback)
        .replace("{token_change_info}", &token_change_info))
}

/// Check if an example needs repair based on QA feedback or computed scores.
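///
/// QA feedback takes precedence: computed scores are only consulted when no
/// QA result exists. A sketch of the decision (hypothetical example values):
///
/// ```ignore
/// // QA present: repair if it reverts edits or confidence <= threshold.
/// assert!(needs_repair(&example_with_low_confidence_qa, 2));
/// // No QA: repair if reversal_ratio > 0.9 or wrong_editable_region is set.
/// assert!(needs_repair(&example_with_high_reversal_score, 2));
/// ```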
pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
    // Check QA-based signals first.
    if let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) {
        if qa.reverts_edits == Some(true) {
            return true;
        }

        if let Some(confidence) = qa.confidence {
            if confidence <= confidence_threshold {
                return true;
            }
        }

        return false;
    }

    // When QA is unavailable, fall back to computed score signals.
    if let Some(score) = example.score.first() {
        if score.reversal_ratio > 0.9 {
            return true;
        }

        if score.wrong_editable_region == Some(true) {
            return true;
        }
    }

    false
}

/// Parse repair model output into a patch and optional cursor.
///
/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
/// and delegates normal output to `TeacherPrompt::parse`.
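///
/// The sentinel is matched anywhere in the output, even outside code blocks
/// (see the test at the bottom of this file):
///
/// ```ignore
/// let (patch, cursor) = parse(&example, "Use `KEEP_PREVIOUS`.")?;
/// // `patch` and `cursor` are copied from the first (teacher) prediction.
/// ```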
pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
    if actual_output.contains(KEEP_PREVIOUS) {
        let original = example
            .predictions
            .first()
            .context("no original prediction to keep")?;
        let patch = original.actual_patch.clone().unwrap_or_default();
        let cursor = original.actual_cursor.clone();
        return Ok((patch, cursor));
    }

    TeacherPrompt::parse(example, actual_output)
}

/// Check if an example already has a successful repair prediction.
fn has_successful_repair(example: &Example) -> bool {
    example
        .predictions
        .iter()
        .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
}

static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();

/// Run repair for a single example.
///
/// This sends a multi-turn conversation to the LLM:
/// - Turn 1 (User): Original teacher prompt
/// - Turn 2 (Assistant): Original teacher response
/// - Turn 3 (User): Repair critique and instructions
/// - Turn 4 (Assistant): Improved prediction (the response we parse)
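///
/// The outcome (including parse failures) is appended to `example.predictions`
/// as a `PredictionProvider::Repair` entry. In batch mode, a `None` response
/// from the client indicates the request was only queued; the example is left
/// untouched until the batch completes and repair is re-run.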
pub async fn run_repair(
    example: &mut Example,
    args: &RepairArgs,
    example_progress: &ExampleProgress,
) -> Result<()> {
    if has_successful_repair(example) {
        return Ok(());
    }

    if !needs_repair(example, args.confidence_threshold) {
        return Ok(());
    }

    run_parse_output(example).context("Failed to execute run_parse_output")?;

    if example.prompt_inputs.is_none() {
        anyhow::bail!("prompt_inputs missing (run context retrieval first)");
    }

    if example.predictions.is_empty() {
        anyhow::bail!("no predictions available (run predict first)");
    }

    let teacher_prompt = example
        .prompt
        .as_ref()
        .context("prompt missing (run format_prompt first)")?;

    let teacher_response = &example.predictions[0].actual_output;
    if teacher_response.is_empty() {
        anyhow::bail!("teacher response is empty (run predict first)");
    }

    let step_progress = example_progress.start(Step::Repair);

    let model = model_for_backend(args.backend);
    let repair_message = build_repair_message(example).context("Failed to build repair message")?;

    step_progress.set_substatus("generating");

    let response = match args.backend {
        BatchProvider::Anthropic => {
            let client = if args.no_batch {
                ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
                    AnthropicClient::plain().expect("Failed to create Anthropic client")
                })
            } else {
                ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                    AnthropicClient::batch(&LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_prompt.input.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 2: Original teacher response
                anthropic::Message {
                    role: anthropic::Role::Assistant,
                    content: vec![anthropic::RequestContent::Text {
                        text: teacher_response.clone(),
                        cache_control: None,
                    }],
                },
                // Turn 3: Repair critique and instructions
                anthropic::Message {
                    role: anthropic::Role::User,
                    content: vec![anthropic::RequestContent::Text {
                        text: repair_message,
                        cache_control: None,
                    }],
                },
            ];

            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            response
                .content
                .iter()
                .filter_map(|c| match c {
                    anthropic::ResponseContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
        BatchProvider::Openai => {
            let client = if args.no_batch {
                OPENAI_CLIENT_PLAIN
                    .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
            } else {
                OPENAI_CLIENT_BATCH.get_or_init(|| {
                    OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
                })
            };

            let messages = vec![
                // Turn 1: Original teacher prompt
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(teacher_prompt.input.clone()),
                },
                // Turn 2: Original teacher response
                open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(teacher_response.clone())),
                    tool_calls: vec![],
                    reasoning_content: None,
                },
                // Turn 3: Repair critique and instructions
                open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(repair_message),
                },
            ];

            let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
                return Ok(());
            };

            response
                .choices
                .into_iter()
                .filter_map(|choice| match choice.message {
                    open_ai::RequestMessage::Assistant { content, .. } => {
                        content.map(|c| match c {
                            open_ai::MessageContent::Plain(text) => text,
                            open_ai::MessageContent::Multipart(parts) => parts
                                .into_iter()
                                .filter_map(|p| match p {
                                    open_ai::MessagePart::Text { text } => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        })
                    }
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("")
        }
    };

    let parse_result = parse(example, &response);
    let err = parse_result
        .as_ref()
        .err()
        .map(|e| format!("Failed to parse repair response: {e}"));

    let (actual_patch, actual_cursor) = parse_result.ok().unzip();
    let actual_cursor = actual_cursor.flatten();

    example.predictions.push(ExamplePrediction {
        actual_patch,
        actual_output: response,
        actual_cursor,
        error: err,
        provider: PredictionProvider::Repair,
        cumulative_logprob: None,
        avg_logprob: None,
    });

    Ok(())
}

/// Sync batches for repair (upload pending requests, download finished results).
pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
    if args.no_batch {
        return Ok(());
    }

    match args.backend {
        BatchProvider::Anthropic => {
            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
            });
            client.sync_batches().await?;
        }
        BatchProvider::Openai => {
            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
            });
            client.sync_batches().await?;
        }
    }

    Ok(())
}

/// Re-run repair for examples whose batch results have arrived.
pub async fn reprocess_after_batch_wait(examples: &mut [Example], args: &RepairArgs) -> Result<()> {
    let mut reprocessed = 0;
    for example in examples.iter_mut() {
        if has_successful_repair(example) || !needs_repair(example, args.confidence_threshold) {
            continue;
        }

        let example_progress = Progress::global().start_group(&example.spec.name);
        run_repair(example, args, &example_progress).await?;
        reprocessed += 1;
    }

    if reprocessed > 0 {
        eprintln!("Reprocessed {reprocessed} example(s) with batch results");
    }

    Ok(())
}

/// Poll until all pending repair batch requests have completed.
pub async fn wait_for_batches(args: &RepairArgs) -> Result<()> {
    if args.no_batch {
        return Ok(());
    }

    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(args)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending repair batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        // NOTE: this blocks the thread between polls; assumes nothing else
        // needs to run on it while we wait.
        std::thread::sleep(poll_interval);

        sync_batches(args).await?;
    }

    Ok(())
}

fn pending_batch_count(args: &RepairArgs) -> Result<usize> {
    match args.backend {
        BatchProvider::Anthropic => {
            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
            });
            client.pending_batch_count()
        }
        BatchProvider::Openai => {
            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
            });
            client.pending_batch_count()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PredictionProvider, TeacherBackend};
    use edit_prediction::example_spec::ExampleSpec;
    use std::{path::Path, sync::Arc};
    use zeta_prompt::ZetaFormat;

    fn example_with_previous_prediction() -> Example {
        Example {
            spec: ExampleSpec {
                name: "example".to_string(),
                repository_url: "https://github.com/zed-industries/zed.git".to_string(),
                revision: "HEAD".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: String::new(),
                cursor_path: Arc::from(Path::new("src/main.rs")),
                cursor_position: "0:0".to_string(),
                edit_history: String::new(),
                expected_patches: Vec::new(),
                rejected_patch: None,
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            },
            prompt_inputs: None,
            prompt: None,
            predictions: vec![ExamplePrediction {
                actual_patch: Some("previous patch".to_string()),
                actual_output: String::new(),
                actual_cursor: Some(ActualCursor {
                    path: "src/main.rs".to_string(),
                    row: 1,
                    column: 2,
                    offset: 3,
                    editable_region_offset: Some(4),
                }),
                error: None,
                provider: PredictionProvider::Teacher(
                    TeacherBackend::Sonnet45,
                    ZetaFormat::default(),
                ),
                cumulative_logprob: None,
                avg_logprob: None,
            }],
            score: Vec::new(),
            qa: Vec::new(),
            zed_version: None,
            state: None,
        }
    }

    #[test]
    fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
        let example = example_with_previous_prediction();
        let actual_output = indoc::indoc! {"
            After reviewing the feedback, the previous prediction is still correct.
            Use `KEEP_PREVIOUS`.

            ```
            unrelated trailing code block
            ```
        "};

        let (patch, cursor) = parse(&example, actual_output).unwrap();

        assert_eq!(patch, "previous patch");
        assert_eq!(cursor.unwrap().offset, 3);
    }
}