ep: Add the repaired-teacher provider

Oleksiy Syvokon created

This is a regular teacher, followed by the QA call, followed by a repair
pass if needed.

For now, this doesn't support batching (too many stages to orchestrate)

Change summary

crates/edit_prediction_cli/src/format_prompt.rs  |   4 
crates/edit_prediction_cli/src/llm_client.rs     | 119 ++++++++++++++++++
crates/edit_prediction_cli/src/main.rs           |  14 +
crates/edit_prediction_cli/src/predict.rs        |  98 ++++++++++++++
crates/edit_prediction_cli/src/prompts/repair.md |   7 
crates/edit_prediction_cli/src/qa.rs             |  97 --------------
crates/edit_prediction_cli/src/repair.rs         | 102 +--------------
7 files changed, 246 insertions(+), 195 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -48,7 +48,9 @@ pub async fn run_format_prompt(
     let snapshot = cx.background_spawn(snapshot_fut).await;
 
     match args.provider {
-        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
+        PredictionProvider::Teacher(_)
+        | PredictionProvider::TeacherNonBatching(_)
+        | PredictionProvider::RepairedTeacher(_) => {
             step_progress.set_substatus("formatting teacher prompt");
 
             let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(

crates/edit_prediction_cli/src/llm_client.rs 🔗

@@ -0,0 +1,119 @@
+//! Shared LLM client abstraction for Anthropic and OpenAI.
+//!
+//! This module provides a unified interface for making LLM requests,
+//! supporting both synchronous and batch modes.
+
+use crate::BatchProvider;
+use crate::anthropic_client::AnthropicClient;
+use crate::openai_client::OpenAiClient;
+use crate::paths::LLM_CACHE_DB;
+use anyhow::Result;
+
+/// A unified LLM client that wraps either Anthropic or OpenAI.
+pub enum LlmClient {
+    Anthropic(AnthropicClient),
+    OpenAi(OpenAiClient),
+}
+
+impl LlmClient {
+    /// Create a new LLM client for the given backend.
+    ///
+    /// If `batched` is true, requests will be queued for batch processing.
+    /// Otherwise, requests are made synchronously.
+    pub fn new(backend: BatchProvider, batched: bool) -> Result<Self> {
+        match backend {
+            BatchProvider::Anthropic => {
+                if batched {
+                    Ok(LlmClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?))
+                } else {
+                    Ok(LlmClient::Anthropic(AnthropicClient::plain()?))
+                }
+            }
+            BatchProvider::Openai => {
+                if batched {
+                    Ok(LlmClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?))
+                } else {
+                    Ok(LlmClient::OpenAi(OpenAiClient::plain()?))
+                }
+            }
+        }
+    }
+
+    /// Generate a response from the LLM.
+    ///
+    /// Returns `Ok(None)` if the request was queued for batch processing
+    /// and results are not yet available.
+    pub async fn generate(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        prompt: &str,
+    ) -> Result<Option<String>> {
+        match self {
+            LlmClient::Anthropic(client) => {
+                let messages = vec![anthropic::Message {
+                    role: anthropic::Role::User,
+                    content: vec![anthropic::RequestContent::Text {
+                        text: prompt.to_string(),
+                        cache_control: None,
+                    }],
+                }];
+                let response = client.generate(model, max_tokens, messages, None).await?;
+                Ok(response.map(|r| {
+                    r.content
+                        .iter()
+                        .filter_map(|c| match c {
+                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+            LlmClient::OpenAi(client) => {
+                let messages = vec![open_ai::RequestMessage::User {
+                    content: open_ai::MessageContent::Plain(prompt.to_string()),
+                }];
+                let response = client.generate(model, max_tokens, messages, None).await?;
+                Ok(response.map(|r| {
+                    r.choices
+                        .into_iter()
+                        .filter_map(|choice| match choice.message {
+                            open_ai::RequestMessage::Assistant { content, .. } => {
+                                content.map(|c| match c {
+                                    open_ai::MessageContent::Plain(text) => text,
+                                    open_ai::MessageContent::Multipart(parts) => parts
+                                        .into_iter()
+                                        .filter_map(|p| match p {
+                                            open_ai::MessagePart::Text { text } => Some(text),
+                                            _ => None,
+                                        })
+                                        .collect::<Vec<_>>()
+                                        .join(""),
+                                })
+                            }
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+        }
+    }
+
+    /// Sync pending batches - upload queued requests and download completed results.
+    pub async fn sync_batches(&self) -> Result<()> {
+        match self {
+            LlmClient::Anthropic(client) => client.sync_batches().await,
+            LlmClient::OpenAi(client) => client.sync_batches().await,
+        }
+    }
+}
+
+/// Get the model name for a given backend.
+pub fn model_for_backend(backend: BatchProvider) -> &'static str {
+    match backend {
+        BatchProvider::Anthropic => "claude-sonnet-4-5",
+        BatchProvider::Openai => "gpt-5.2",
+    }
+}

crates/edit_prediction_cli/src/main.rs 🔗

@@ -5,6 +5,7 @@ mod filter_languages;
 mod format_prompt;
 mod git;
 mod headless;
+mod llm_client;
 mod load_project;
 mod metrics;
 mod openai_client;
@@ -291,6 +292,7 @@ enum PredictionProvider {
     Zeta2(ZetaVersion),
     Teacher(TeacherBackend),
     TeacherNonBatching(TeacherBackend),
+    RepairedTeacher(TeacherBackend),
     Repair,
 }
 
@@ -311,6 +313,9 @@ impl std::fmt::Display for PredictionProvider {
             PredictionProvider::TeacherNonBatching(backend) => {
                 write!(f, "teacher-non-batching:{backend}")
             }
+            PredictionProvider::RepairedTeacher(backend) => {
+                write!(f, "repaired-teacher:{backend}")
+            }
             PredictionProvider::Repair => write!(f, "repair"),
         }
     }
@@ -345,10 +350,17 @@ impl std::str::FromStr for PredictionProvider {
                     .unwrap_or(TeacherBackend::Sonnet45);
                 Ok(PredictionProvider::TeacherNonBatching(backend))
             }
+            "repaired-teacher" | "repaired_teacher" | "repairedteacher" => {
+                let backend = arg
+                    .map(|a| a.parse())
+                    .transpose()?
+                    .unwrap_or(TeacherBackend::Sonnet45);
+                Ok(PredictionProvider::RepairedTeacher(backend))
+            }
             "repair" => Ok(PredictionProvider::Repair),
             _ => {
                 anyhow::bail!(
-                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
+                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repaired-teacher, repair\n\
                  For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                  For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                  Available zeta versions:\n{}",

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -1,13 +1,16 @@
 use crate::{
-    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
+    BatchProvider, FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
     anthropic_client::AnthropicClient,
     example::{Example, ExamplePrediction, ExamplePrompt},
     format_prompt::{TeacherPrompt, run_format_prompt},
     headless::EpAppState,
+    llm_client::{LlmClient, model_for_backend},
     load_project::run_load_project,
     openai_client::OpenAiClient,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
     progress::{ExampleProgress, InfoStyle, Step},
+    qa,
+    repair::{build_repair_prompt, needs_repair, parse_repair_response},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
@@ -72,6 +75,21 @@ pub async fn run_prediction(
         return predict_teacher(example, backend, batched, repetition_count).await;
     }
 
+    if let PredictionProvider::RepairedTeacher(backend) = provider {
+        let _step_progress = example_progress.start(Step::Predict);
+
+        run_format_prompt(
+            example,
+            &FormatPromptArgs { provider },
+            app_state.clone(),
+            example_progress,
+            cx,
+        )
+        .await?;
+
+        return predict_repaired_teacher(example, backend, repetition_count).await;
+    }
+
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 
     let step_progress = example_progress.start(Step::Predict);
@@ -110,6 +128,7 @@ pub async fn run_prediction(
             PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
             PredictionProvider::Teacher(..)
             | PredictionProvider::TeacherNonBatching(..)
+            | PredictionProvider::RepairedTeacher(..)
             | PredictionProvider::Repair => {
                 unreachable!()
             }
@@ -407,6 +426,83 @@ async fn predict_openai(
     Ok(())
 }
 
+/// Default confidence threshold for repair
+const DEFAULT_REPAIR_CONFIDENCE_THRESHOLD: u8 = 3;
+
+/// Predict using teacher model, then run QA evaluation, and optionally repair
+/// if QA indicates issues (reverts_edits=true or low confidence).
+///
+/// This is a non-batched flow that processes each step synchronously.
+async fn predict_repaired_teacher(
+    example: &mut Example,
+    backend: TeacherBackend,
+    repetition_count: usize,
+) -> anyhow::Result<()> {
+    // Step 1: Run teacher prediction (non-batched for immediate results)
+    predict_teacher(example, backend, false, repetition_count).await?;
+
+    // Only proceed with QA/repair for the first prediction
+    let Some(prediction) = example.predictions.first() else {
+        return Ok(());
+    };
+
+    // Skip QA if no actual patch was generated
+    if prediction.actual_patch.is_none() {
+        return Ok(());
+    }
+
+    // Step 2: Run QA evaluation
+    let batch_provider = match backend {
+        TeacherBackend::Sonnet45 => BatchProvider::Anthropic,
+        TeacherBackend::Gpt52 => BatchProvider::Openai,
+    };
+    let qa_client = LlmClient::new(batch_provider, false)?;
+    let qa_model = model_for_backend(batch_provider);
+
+    let qa_result = if let Some(qa_prompt) = qa::build_prompt(example) {
+        match qa_client.generate(qa_model, 1024, &qa_prompt).await? {
+            Some(response_text) => Some(qa::parse_response(&response_text)),
+            None => None,
+        }
+    } else {
+        None
+    };
+
+    // Store QA result
+    example.qa = vec![qa_result.clone()];
+
+    // Step 3: Check if repair is needed and run repair if so
+    if needs_repair(example, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD) {
+        let repair_client = LlmClient::new(batch_provider, false)?;
+
+        if let Some(repair_prompt) = build_repair_prompt(example) {
+            if let Some(response_text) = repair_client
+                .generate(qa_model, 16384, &repair_prompt)
+                .await?
+            {
+                match parse_repair_response(example, &response_text) {
+                    Ok(mut repaired_prediction) => {
+                        // Mark the prediction as coming from repaired-teacher
+                        repaired_prediction.provider = PredictionProvider::RepairedTeacher(backend);
+                        example.predictions.push(repaired_prediction);
+                    }
+                    Err(e) => {
+                        // Add error prediction if parsing failed
+                        example.predictions.push(ExamplePrediction {
+                            actual_patch: None,
+                            actual_output: response_text,
+                            error: Some(format!("Failed to parse repair response: {}", e)),
+                            provider: PredictionProvider::RepairedTeacher(backend),
+                        });
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
 pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
     match provider {
         Some(PredictionProvider::Teacher(backend)) => match backend {

crates/edit_prediction_cli/src/prompts/repair.md 🔗

@@ -18,6 +18,7 @@ A previous model generated a prediction that was judged to have issues. Your job
 - Keep existing formatting unless it's absolutely necessary
 - Don't write a lot of code if you're not sure what to do
 - Do not delete or remove text that was just added in the edit history. If a recent edit introduces incomplete or incorrect code, finish or fix it in place, or simply do nothing rather than removing it. Only remove a recent edit if the history explicitly shows the user undoing it themselves.
+- When uncertain, predict only the minimal, high-confidence portion of the edit. Prefer a small, correct prediction over a large, speculative one
 
 # Input Format
 
@@ -34,9 +35,9 @@ You will be provided with:
 # Output Format
 
 - Briefly explain what was wrong with the previous prediction and how you'll improve it.
-- Output the entire editable region, applying the edits that you predict the user will make next.
-- If you're unsure about some portion of the next edit, you may still predict the surrounding code (such as a function definition, `for` loop, etc) and place the `<|user_cursor|>` within it for the user to fill in.
+- Output a markdown codeblock containing **only** the editable region with your predicted edits applied. The codeblock must start with `<|editable_region_start|>` and end with `<|editable_region_end|>`. Do not include any content before or after these tags.
 - Wrap the edited code in a codeblock with exactly five backticks.
+- If you're unsure about some portion of the next edit, you may still predict the surrounding code (such as a function definition, `for` loop, etc) and place the `<|user_cursor|>` within it for the user to fill in.
 
 # 1. User Edits History
 
@@ -68,4 +69,4 @@ The previous model generated the following edit (in word-diff format):
 
 # Your Improved Prediction
 
-Based on the feedback above, generate an improved prediction. Address the issues identified in the quality assessment.
+Based on the feedback above, generate an improved prediction. Address the issues identified in the quality assessment. Prefer a small, correct prediction over a large, speculative one.

crates/edit_prediction_cli/src/qa.rs 🔗

@@ -4,11 +4,9 @@
 //! Caching is handled by the underlying client.
 
 use crate::BatchProvider;
-use crate::anthropic_client::AnthropicClient;
 use crate::example::Example;
 use crate::format_prompt::extract_cursor_excerpt_from_example;
-use crate::openai_client::OpenAiClient;
-use crate::paths::LLM_CACHE_DB;
+use crate::llm_client::{LlmClient, model_for_backend};
 use crate::word_diff::unified_to_word_diff;
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
@@ -33,13 +31,6 @@ pub struct QaArgs {
     pub backend: BatchProvider,
 }
 
-fn model_for_backend(backend: BatchProvider) -> &'static str {
-    match backend {
-        BatchProvider::Anthropic => "claude-sonnet-4-5",
-        BatchProvider::Openai => "gpt-5.2",
-    }
-}
-
 /// Result of QA evaluation for a single prediction.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct QaResult {
@@ -120,7 +111,7 @@ fn extract_codeblock(response: &str) -> Option<String> {
 }
 
 /// Parse the LLM response into a QaResult.
-fn parse_response(response_text: &str) -> QaResult {
+pub(crate) fn parse_response(response_text: &str) -> QaResult {
     let codeblock = extract_codeblock(response_text);
 
     // Try parsing codeblock first, then fall back to raw response
@@ -156,73 +147,6 @@ fn parse_response(response_text: &str) -> QaResult {
     }
 }
 
-enum QaClient {
-    Anthropic(AnthropicClient),
-    OpenAi(OpenAiClient),
-}
-
-impl QaClient {
-    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
-        match self {
-            QaClient::Anthropic(client) => {
-                let messages = vec![anthropic::Message {
-                    role: anthropic::Role::User,
-                    content: vec![anthropic::RequestContent::Text {
-                        text: prompt.to_string(),
-                        cache_control: None,
-                    }],
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.content
-                        .iter()
-                        .filter_map(|c| match c {
-                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-            QaClient::OpenAi(client) => {
-                let messages = vec![open_ai::RequestMessage::User {
-                    content: open_ai::MessageContent::Plain(prompt.to_string()),
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.choices
-                        .into_iter()
-                        .filter_map(|choice| match choice.message {
-                            open_ai::RequestMessage::Assistant { content, .. } => {
-                                content.map(|c| match c {
-                                    open_ai::MessageContent::Plain(text) => text,
-                                    open_ai::MessageContent::Multipart(parts) => parts
-                                        .into_iter()
-                                        .filter_map(|p| match p {
-                                            open_ai::MessagePart::Text { text } => Some(text),
-                                            _ => None,
-                                        })
-                                        .collect::<Vec<_>>()
-                                        .join(""),
-                                })
-                            }
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-        }
-    }
-
-    async fn sync_batches(&self) -> Result<()> {
-        match self {
-            QaClient::Anthropic(client) => client.sync_batches().await,
-            QaClient::OpenAi(client) => client.sync_batches().await,
-        }
-    }
-}
-
 /// Run the QA evaluation on a set of examples.
 pub async fn run_qa(
     examples: &mut [Example],
@@ -230,22 +154,7 @@ pub async fn run_qa(
     output_path: Option<&PathBuf>,
 ) -> Result<()> {
     let model = model_for_backend(args.backend);
-    let client = match args.backend {
-        BatchProvider::Anthropic => {
-            if args.no_batch {
-                QaClient::Anthropic(AnthropicClient::plain()?)
-            } else {
-                QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-        BatchProvider::Openai => {
-            if args.no_batch {
-                QaClient::OpenAi(OpenAiClient::plain()?)
-            } else {
-                QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-    };
+    let client = LlmClient::new(args.backend, !args.no_batch)?;
 
     eprintln!(
         "Using model: {}, backend: {:?}, batching: {}",

crates/edit_prediction_cli/src/repair.rs 🔗

@@ -6,11 +6,9 @@
 
 use crate::BatchProvider;
 use crate::PredictionProvider;
-use crate::anthropic_client::AnthropicClient;
 use crate::example::{Example, ExamplePrediction};
 use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
-use crate::openai_client::OpenAiClient;
-use crate::paths::LLM_CACHE_DB;
+use crate::llm_client::{LlmClient, model_for_backend};
 use crate::word_diff::unified_to_word_diff;
 use anyhow::Result;
 use std::io::{BufWriter, Write};
@@ -30,7 +28,7 @@ pub struct RepairArgs {
     pub wait: bool,
 
     /// Confidence threshold: repair predictions with confidence <= this value (1-5)
-    #[clap(long, default_value = "2")]
+    #[clap(long, default_value = "3")]
     pub confidence_threshold: u8,
 
     /// Which LLM provider to use (anthropic or openai)
@@ -38,13 +36,6 @@ pub struct RepairArgs {
     pub backend: BatchProvider,
 }
 
-fn model_for_backend(backend: BatchProvider) -> &'static str {
-    match backend {
-        BatchProvider::Anthropic => "claude-sonnet-4-5",
-        BatchProvider::Openai => "gpt-5.2",
-    }
-}
-
 /// Build the repair prompt for an example that needs improvement.
 ///
 /// Returns None if the example doesn't have the required data (predictions, qa, prompt_inputs).
@@ -125,7 +116,10 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
 }
 
 /// Parse the repair response into a prediction.
-fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
+pub(crate) fn parse_repair_response(
+    example: &Example,
+    response_text: &str,
+) -> Result<ExamplePrediction> {
     let actual_patch = TeacherPrompt::parse(example, response_text)?;
 
     Ok(ExamplePrediction {
@@ -136,73 +130,6 @@ fn parse_repair_response(example: &Example, response_text: &str) -> Result<Examp
     })
 }
 
-enum RepairClient {
-    Anthropic(AnthropicClient),
-    OpenAi(OpenAiClient),
-}
-
-impl RepairClient {
-    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
-        match self {
-            RepairClient::Anthropic(client) => {
-                let messages = vec![anthropic::Message {
-                    role: anthropic::Role::User,
-                    content: vec![anthropic::RequestContent::Text {
-                        text: prompt.to_string(),
-                        cache_control: None,
-                    }],
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.content
-                        .iter()
-                        .filter_map(|c| match c {
-                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-            RepairClient::OpenAi(client) => {
-                let messages = vec![open_ai::RequestMessage::User {
-                    content: open_ai::MessageContent::Plain(prompt.to_string()),
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.choices
-                        .into_iter()
-                        .filter_map(|choice| match choice.message {
-                            open_ai::RequestMessage::Assistant { content, .. } => {
-                                content.map(|c| match c {
-                                    open_ai::MessageContent::Plain(text) => text,
-                                    open_ai::MessageContent::Multipart(parts) => parts
-                                        .into_iter()
-                                        .filter_map(|p| match p {
-                                            open_ai::MessagePart::Text { text } => Some(text),
-                                            _ => None,
-                                        })
-                                        .collect::<Vec<_>>()
-                                        .join(""),
-                                })
-                            }
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-        }
-    }
-
-    async fn sync_batches(&self) -> Result<()> {
-        match self {
-            RepairClient::Anthropic(client) => client.sync_batches().await,
-            RepairClient::OpenAi(client) => client.sync_batches().await,
-        }
-    }
-}
-
 /// Run the repair process on a set of examples.
 pub async fn run_repair(
     examples: &mut [Example],
@@ -210,22 +137,7 @@ pub async fn run_repair(
     output_path: Option<&PathBuf>,
 ) -> Result<()> {
     let model = model_for_backend(args.backend);
-    let client = match args.backend {
-        BatchProvider::Anthropic => {
-            if args.no_batch {
-                RepairClient::Anthropic(AnthropicClient::plain()?)
-            } else {
-                RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-        BatchProvider::Openai => {
-            if args.no_batch {
-                RepairClient::OpenAi(OpenAiClient::plain()?)
-            } else {
-                RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-    };
+    let client = LlmClient::new(args.backend, !args.no_batch)?;
 
     eprintln!(
         "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}",