1//! Repair predictions that received poor quality signals.
2//!
3//! This module takes examples with predictions, identifies predictions that need
4//! improvement, and uses an LLM to generate improved predictions. It supports
5//! two sources of quality signals:
6//! - QA feedback (reverts_edits or low confidence)
7//! - Computed scores when QA is unavailable (high reversal_ratio or wrong_editable_region)
8
9use crate::{
10 BatchProvider, PredictionProvider,
11 anthropic_client::AnthropicClient,
12 example::{ActualCursor, Example, ExamplePrediction},
13 format_prompt::TeacherPrompt,
14 metrics::count_patch_token_changes,
15 openai_client::OpenAiClient,
16 parse_output::run_parse_output,
17 paths::LLM_CACHE_DB,
18 progress::{ExampleProgress, Progress, Step},
19 word_diff::unified_to_word_diff,
20};
21use anyhow::{Context as _, Result};
22use std::sync::OnceLock;
23
/// Sentinel the repair model may emit to request that the teacher's original
/// prediction be kept unchanged; `parse` checks for it as a substring.
const KEEP_PREVIOUS: &str = "KEEP_PREVIOUS";
25
26/// Print a summary report of repair results across all examples.
27pub fn print_report(examples: &[Example], confidence_threshold: u8) {
28 let total = examples.len();
29 let mut no_repair_needed = 0;
30 let mut repaired = 0;
31 let mut repair_failed = 0;
32
33 for example in examples {
34 if !needs_repair(example, confidence_threshold) {
35 no_repair_needed += 1;
36 continue;
37 }
38
39 if has_successful_repair(example) {
40 repaired += 1;
41 } else {
42 repair_failed += 1;
43 }
44 }
45
46 let needed_repair = total - no_repair_needed;
47
48 eprintln!();
49 eprintln!("Repair summary ({total} examples):");
50 eprintln!(
51 " {no_repair_needed}/{total} didn't need repair (confidence > {confidence_threshold})"
52 );
53 if needed_repair > 0 {
54 eprintln!(" {needed_repair}/{total} needed repair:");
55 if repaired > 0 {
56 eprintln!(" {repaired} repaired successfully");
57 }
58 if repair_failed > 0 {
59 eprintln!(" {repair_failed} failed to repair");
60 }
61 }
62}
63
64/// Arguments for the repair command.
65#[derive(Debug, Clone, clap::Args)]
66pub struct RepairArgs {
67 /// Use synchronous API instead of batch
68 #[clap(long)]
69 pub no_batch: bool,
70
71 /// Confidence threshold: repair predictions with confidence <= this value (1-5)
72 #[clap(long, default_value = "2")]
73 pub confidence_threshold: u8,
74
75 /// Which LLM provider to use (anthropic or openai)
76 #[clap(long, default_value = "anthropic")]
77 pub backend: BatchProvider,
78 /// Wait for all batches to complete before exiting
79 #[clap(long)]
80 pub wait: bool,
81}
82
83fn model_for_backend(backend: BatchProvider) -> &'static str {
84 match backend {
85 BatchProvider::Anthropic => "claude-sonnet-4-6",
86 BatchProvider::Openai => "gpt-5.2",
87 }
88}
89
90/// Build the quality feedback string from QA results.
91fn build_qa_feedback(example: &Example) -> Option<String> {
92 let qa = example.qa.first()?.as_ref()?;
93
94 let qa_reasoning = qa.reasoning.as_deref().unwrap_or("No reasoning provided");
95 let reverts_edits = qa
96 .reverts_edits
97 .map_or("unknown", |v| if v { "yes" } else { "no" });
98 let confidence = qa
99 .confidence
100 .map_or("unknown".to_string(), |v| v.to_string());
101
102 Some(format!(
103 "- **Reverts user edits**: {reverts_edits}\n\
104 - **Confidence score**: {confidence}/5\n\
105 - **Reasoning**: {qa_reasoning}"
106 ))
107}
108
109/// Build the quality feedback string from computed scores when QA is unavailable.
110fn build_score_feedback(example: &Example) -> Option<String> {
111 let score = example.score.first()?;
112
113 let mut issues = Vec::new();
114
115 if score.reversal_ratio > 0.9 {
116 issues.push(format!(
117 "Automated analysis detected a high reversal ratio ({:.2}), which suggests this \
118 prediction may be reverting changes the user intentionally made. Double-check that \
119 the prediction doesn't undo the user's recent edits. If the prediction is actually \
120 fine and the edits are intentional completions rather than reversals, keep it as-is. \
121 If it truly reverts the user's changes, generate an improved prediction that \
122 continues the user's intent instead.",
123 score.reversal_ratio
124 ));
125 }
126
127 if score.wrong_editable_region == Some(true) {
128 issues.push(
129 "Automated analysis detected that the prediction may be modifying code outside \
130 the expected editable region, or producing changes misaligned with the editable \
131 region boundaries. Make sure the prediction only modifies code within the editable \
132 region and is properly aligned."
133 .to_string(),
134 );
135 }
136
137 if issues.is_empty() {
138 return None;
139 }
140
141 let mut feedback = String::from(
142 "No human quality assessment is available, but automated scoring flagged potential issues:\n\n",
143 );
144 for issue in &issues {
145 feedback.push_str(&format!("- {issue}\n"));
146 }
147 feedback.push_str(
148 "\nRemember: if the previous prediction was actually correct, output `KEEP_PREVIOUS`. \
149 If no edits should be made at all and you are unsure how to improve it, output `NO_EDITS`.",
150 );
151
152 Some(feedback)
153}
154
155/// Build the repair message (Turn 3) for a multi-turn conversation.
156///
157/// This message is sent after the original teacher prompt (Turn 1) and
158/// teacher response (Turn 2) to request an improved prediction.
159pub fn build_repair_message(example: &Example) -> Result<String> {
160 let prediction = example
161 .predictions
162 .first()
163 .context("no predictions available")?;
164 let actual_patch = prediction
165 .actual_patch
166 .as_ref()
167 .context("no actual_patch available (run predict first)")?;
168
169 let quality_feedback = build_qa_feedback(example)
170 .or_else(|| build_score_feedback(example))
171 .context("no quality feedback available (need either QA results or computed scores)")?;
172
173 let actual_patch_word_diff = unified_to_word_diff(actual_patch);
174
175 let token_counts = count_patch_token_changes(actual_patch);
176 let mut token_change_info = format!(
177 "\n## Token Change Statistics\n\n\
178 - **Deleted tokens**: {}\n\
179 - **Inserted tokens**: {}",
180 token_counts.deleted_tokens, token_counts.inserted_tokens,
181 );
182 if token_counts.deleted_tokens > 100 || token_counts.inserted_tokens > 100 {
183 token_change_info.push_str(
184 "\n\n> **Note:** The token change count is high. \
185 Consider producing a more scoped edit that targets only the lines \
186 that truly need to change, rather than rewriting large sections.",
187 );
188 }
189
190 let prompt_template = crate::prompt_assets::get_prompt("repair.md");
191 Ok(prompt_template
192 .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
193 .replace("{quality_feedback}", &quality_feedback)
194 .replace("{token_change_info}", &token_change_info))
195}
196
197/// Check if an example needs repair based on QA feedback or computed scores.
198pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
199 // Check QA-based signals first.
200 if let Some(qa) = example.qa.first().and_then(|q| q.as_ref()) {
201 if qa.reverts_edits == Some(true) {
202 return true;
203 }
204
205 if let Some(confidence) = qa.confidence {
206 if confidence <= confidence_threshold {
207 return true;
208 }
209 }
210
211 return false;
212 }
213
214 // When QA is unavailable, fall back to computed score signals.
215 if let Some(score) = example.score.first() {
216 if score.reversal_ratio > 0.9 {
217 return true;
218 }
219
220 if score.wrong_editable_region == Some(true) {
221 return true;
222 }
223 }
224
225 false
226}
227
228/// Parse repair model output into a patch and optional cursor.
229///
230/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
231/// and delegates normal output to `TeacherPrompt::parse`.
232pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
233 if actual_output.contains(KEEP_PREVIOUS) {
234 let original = example
235 .predictions
236 .first()
237 .context("no original prediction to keep")?;
238 let patch = original.actual_patch.clone().unwrap_or_default();
239 let cursor = original.actual_cursor.clone();
240 return Ok((patch, cursor));
241 }
242
243 TeacherPrompt::parse(example, actual_output)
244}
245
246/// Check if an example already has a successful repair prediction.
247fn has_successful_repair(example: &Example) -> bool {
248 example
249 .predictions
250 .iter()
251 .any(|p| p.provider == PredictionProvider::Repair && p.actual_patch.is_some())
252}
253
254static ANTHROPIC_CLIENT_BATCH: OnceLock<AnthropicClient> = OnceLock::new();
255static ANTHROPIC_CLIENT_PLAIN: OnceLock<AnthropicClient> = OnceLock::new();
256static OPENAI_CLIENT_BATCH: OnceLock<OpenAiClient> = OnceLock::new();
257static OPENAI_CLIENT_PLAIN: OnceLock<OpenAiClient> = OnceLock::new();
258
259/// Run repair for a single example.
260///
261/// This sends a multi-turn conversation to the LLM:
262/// - Turn 1 (User): Original teacher prompt
263/// - Turn 2 (Assistant): Original teacher response
264/// - Turn 3 (User): Repair critique and instructions
265/// - Turn 4 (Assistant): Improved prediction (the response we parse)
266pub async fn run_repair(
267 example: &mut Example,
268 args: &RepairArgs,
269 example_progress: &ExampleProgress,
270) -> Result<()> {
271 if has_successful_repair(example) {
272 return Ok(());
273 }
274
275 if !needs_repair(example, args.confidence_threshold) {
276 return Ok(());
277 }
278
279 run_parse_output(example).context("Failed to execute run_parse_output")?;
280
281 if example.prompt_inputs.is_none() {
282 anyhow::bail!("prompt_inputs missing (run context retrieval first)");
283 }
284
285 if example.predictions.is_empty() {
286 anyhow::bail!("no predictions available (run predict first)");
287 }
288
289 let teacher_prompt = example
290 .prompt
291 .as_ref()
292 .context("prompt missing (run format_prompt first)")?;
293
294 let teacher_response = &example.predictions[0].actual_output;
295 if teacher_response.is_empty() {
296 anyhow::bail!("teacher response is empty (run predict first)");
297 }
298
299 let step_progress = example_progress.start(Step::Repair);
300
301 let model = model_for_backend(args.backend);
302 let repair_message = build_repair_message(example).context("Failed to build repair message")?;
303
304 step_progress.set_substatus("generating");
305
306 let response = match args.backend {
307 BatchProvider::Anthropic => {
308 let client = if args.no_batch {
309 ANTHROPIC_CLIENT_PLAIN.get_or_init(|| {
310 AnthropicClient::plain().expect("Failed to create Anthropic client")
311 })
312 } else {
313 ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
314 AnthropicClient::batch(&LLM_CACHE_DB)
315 .expect("Failed to create Anthropic client")
316 })
317 };
318
319 let messages = vec![
320 // Turn 1: Original teacher prompt
321 anthropic::Message {
322 role: anthropic::Role::User,
323 content: vec![anthropic::RequestContent::Text {
324 text: teacher_prompt.input.clone(),
325 cache_control: None,
326 }],
327 },
328 // Turn 2: Original teacher response
329 anthropic::Message {
330 role: anthropic::Role::Assistant,
331 content: vec![anthropic::RequestContent::Text {
332 text: teacher_response.clone(),
333 cache_control: None,
334 }],
335 },
336 // Turn 3: Repair critique and instructions
337 anthropic::Message {
338 role: anthropic::Role::User,
339 content: vec![anthropic::RequestContent::Text {
340 text: repair_message,
341 cache_control: None,
342 }],
343 },
344 ];
345
346 let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
347 return Ok(());
348 };
349
350 response
351 .content
352 .iter()
353 .filter_map(|c| match c {
354 anthropic::ResponseContent::Text { text } => Some(text.as_str()),
355 _ => None,
356 })
357 .collect::<Vec<_>>()
358 .join("")
359 }
360 BatchProvider::Openai => {
361 let client = if args.no_batch {
362 OPENAI_CLIENT_PLAIN
363 .get_or_init(|| OpenAiClient::plain().expect("Failed to create OpenAI client"))
364 } else {
365 OPENAI_CLIENT_BATCH.get_or_init(|| {
366 OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
367 })
368 };
369
370 let messages = vec![
371 // Turn 1: Original teacher prompt
372 open_ai::RequestMessage::User {
373 content: open_ai::MessageContent::Plain(teacher_prompt.input.clone()),
374 },
375 // Turn 2: Original teacher response
376 open_ai::RequestMessage::Assistant {
377 content: Some(open_ai::MessageContent::Plain(teacher_response.clone())),
378 tool_calls: vec![],
379 },
380 // Turn 3: Repair critique and instructions
381 open_ai::RequestMessage::User {
382 content: open_ai::MessageContent::Plain(repair_message),
383 },
384 ];
385
386 let Some(response) = client.generate(model, 16384, messages, None, false).await? else {
387 return Ok(());
388 };
389
390 response
391 .choices
392 .into_iter()
393 .filter_map(|choice| match choice.message {
394 open_ai::RequestMessage::Assistant { content, .. } => {
395 content.map(|c| match c {
396 open_ai::MessageContent::Plain(text) => text,
397 open_ai::MessageContent::Multipart(parts) => parts
398 .into_iter()
399 .filter_map(|p| match p {
400 open_ai::MessagePart::Text { text } => Some(text),
401 _ => None,
402 })
403 .collect::<Vec<_>>()
404 .join(""),
405 })
406 }
407 _ => None,
408 })
409 .collect::<Vec<_>>()
410 .join("")
411 }
412 };
413
414 let parse_result = parse(example, &response);
415 let err = parse_result
416 .as_ref()
417 .err()
418 .map(|e| format!("Failed to parse repair response: {}", e));
419
420 let (actual_patch, actual_cursor) = parse_result.ok().unzip();
421 let actual_cursor = actual_cursor.flatten();
422
423 example.predictions.push(ExamplePrediction {
424 actual_patch,
425 actual_output: response,
426 actual_cursor,
427 error: err,
428 provider: PredictionProvider::Repair,
429 cumulative_logprob: None,
430 avg_logprob: None,
431 });
432
433 Ok(())
434}
435
436/// Sync batches for repair (upload pending requests, download finished results).
437pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
438 if args.no_batch {
439 return Ok(());
440 }
441
442 match args.backend {
443 BatchProvider::Anthropic => {
444 let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
445 AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
446 });
447 client.sync_batches().await?;
448 }
449 BatchProvider::Openai => {
450 let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
451 OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
452 });
453 client.sync_batches().await?;
454 }
455 }
456
457 Ok(())
458}
459
460pub async fn reprocess_after_batch_wait(examples: &mut [Example], args: &RepairArgs) -> Result<()> {
461 let mut reprocessed = 0;
462 for example in examples.iter_mut() {
463 if has_successful_repair(example) || !needs_repair(example, args.confidence_threshold) {
464 continue;
465 }
466
467 let example_progress = Progress::global().start_group(&example.spec.name);
468 run_repair(example, args, &example_progress).await?;
469 reprocessed += 1;
470 }
471
472 if reprocessed > 0 {
473 eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
474 }
475
476 Ok(())
477}
478
479pub async fn wait_for_batches(args: &RepairArgs) -> Result<()> {
480 if args.no_batch {
481 return Ok(());
482 }
483
484 let poll_interval = std::time::Duration::from_secs(30);
485
486 loop {
487 let pending = pending_batch_count(args)?;
488 if pending == 0 {
489 break;
490 }
491
492 eprintln!(
493 "Waiting for {} pending repair batch request(s) to complete... (polling every {}s)",
494 pending,
495 poll_interval.as_secs()
496 );
497 std::thread::sleep(poll_interval);
498
499 sync_batches(args).await?;
500 }
501
502 Ok(())
503}
504
505fn pending_batch_count(args: &RepairArgs) -> Result<usize> {
506 match args.backend {
507 BatchProvider::Anthropic => {
508 let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
509 AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
510 });
511 client.pending_batch_count()
512 }
513 BatchProvider::Openai => {
514 let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
515 OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
516 });
517 client.pending_batch_count()
518 }
519 }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PredictionProvider, TeacherBackend};
    use edit_prediction::example_spec::ExampleSpec;
    use std::{path::Path, sync::Arc};

    /// Builds an example whose only prediction is a teacher prediction with a
    /// known patch and cursor, so tests can observe what `KEEP_PREVIOUS` copies.
    fn example_with_previous_prediction() -> Example {
        let spec = ExampleSpec {
            name: "example".to_string(),
            repository_url: "https://github.com/zed-industries/zed.git".to_string(),
            revision: "HEAD".to_string(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Arc::from(Path::new("src/main.rs")),
            cursor_position: "0:0".to_string(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        let previous_cursor = ActualCursor {
            path: "src/main.rs".to_string(),
            row: 1,
            column: 2,
            offset: 3,
            editable_region_offset: Some(4),
        };

        let teacher_prediction = ExamplePrediction {
            actual_patch: Some("previous patch".to_string()),
            actual_output: String::new(),
            actual_cursor: Some(previous_cursor),
            error: None,
            provider: PredictionProvider::Teacher(TeacherBackend::Sonnet45),
            cumulative_logprob: None,
            avg_logprob: None,
        };

        Example {
            spec,
            prompt_inputs: None,
            prompt: None,
            predictions: vec![teacher_prediction],
            score: Vec::new(),
            qa: Vec::new(),
            zed_version: None,
            state: None,
        }
    }

    #[test]
    fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
        let model_output = indoc::indoc! {"
            After reviewing the feedback, the previous prediction is still correct.
            Use `KEEP_PREVIOUS`.

            ```
            unrelated trailing code block
            ```
        "};
        let example = example_with_previous_prediction();

        let (patch, cursor) = parse(&example, model_output).unwrap();

        assert_eq!(patch, "previous patch");
        assert_eq!(cursor.map(|cursor| cursor.offset), Some(3));
    }
}