From c1f0df1fec3eb57b4701d420921d5fb0d3deb293 Mon Sep 17 00:00:00 2001
From: Oleksiy Syvokon
Date: Wed, 28 Jan 2026 16:30:18 +0200
Subject: [PATCH] ep: Add the repaired-teacher provider

This is a regular teacher prediction, followed by a QA call, followed by a
repair pass if needed.

For now, this doesn't support batching (too many stages to orchestrate).
---
 .../edit_prediction_cli/src/format_prompt.rs  |   4 +-
 crates/edit_prediction_cli/src/llm_client.rs  | 119 ++++++++++++++++++
 crates/edit_prediction_cli/src/main.rs        |  14 ++-
 crates/edit_prediction_cli/src/predict.rs     |  98 ++++++++++++++-
 .../edit_prediction_cli/src/prompts/repair.md |   7 +-
 crates/edit_prediction_cli/src/qa.rs          |  97 +-------------
 crates/edit_prediction_cli/src/repair.rs      | 102 ++-------------
 7 files changed, 246 insertions(+), 195 deletions(-)
 create mode 100644 crates/edit_prediction_cli/src/llm_client.rs

diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs
index 9813615540f1b6d58dfabf558fd18526e8d38d1d..8954a65d2efb115bd4a471a6fd071faf747f02f2 100644
--- a/crates/edit_prediction_cli/src/format_prompt.rs
+++ b/crates/edit_prediction_cli/src/format_prompt.rs
@@ -48,7 +48,9 @@ pub async fn run_format_prompt(
     let snapshot = cx.background_spawn(snapshot_fut).await;
 
     match args.provider {
-        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
+        PredictionProvider::Teacher(_)
+        | PredictionProvider::TeacherNonBatching(_)
+        | PredictionProvider::RepairedTeacher(_) => {
             step_progress.set_substatus("formatting teacher prompt");
 
             let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
diff --git a/crates/edit_prediction_cli/src/llm_client.rs b/crates/edit_prediction_cli/src/llm_client.rs
new file mode 100644
index 0000000000000000000000000000000000000000..14d616eab0b2ca38980e6a5aff17ae2f4a5fe9fe
--- /dev/null
+++ b/crates/edit_prediction_cli/src/llm_client.rs
@@ -0,0 +1,119 @@
+//! Shared LLM client abstraction for Anthropic and OpenAI.
+//!
+//! This module provides a unified interface for making LLM requests,
+//! supporting both synchronous and batch modes.
+
+use crate::BatchProvider;
+use crate::anthropic_client::AnthropicClient;
+use crate::openai_client::OpenAiClient;
+use crate::paths::LLM_CACHE_DB;
+use anyhow::Result;
+
+/// A unified LLM client that wraps either Anthropic or OpenAI.
+pub enum LlmClient {
+    Anthropic(AnthropicClient),
+    OpenAi(OpenAiClient),
+}
+
+impl LlmClient {
+    /// Create a new LLM client for the given backend.
+    ///
+    /// If `batched` is true, requests will be queued for batch processing.
+    /// Otherwise, requests are made synchronously.
+    pub fn new(backend: BatchProvider, batched: bool) -> Result<Self> {
+        match backend {
+            BatchProvider::Anthropic => {
+                if batched {
+                    Ok(LlmClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?))
+                } else {
+                    Ok(LlmClient::Anthropic(AnthropicClient::plain()?))
+                }
+            }
+            BatchProvider::Openai => {
+                if batched {
+                    Ok(LlmClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?))
+                } else {
+                    Ok(LlmClient::OpenAi(OpenAiClient::plain()?))
+                }
+            }
+        }
+    }
+
+    /// Generate a response from the LLM.
+    ///
+    /// Returns `Ok(None)` if the request was queued for batch processing
+    /// and results are not yet available.
+    pub async fn generate(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        prompt: &str,
+    ) -> Result<Option<String>> {
+        match self {
+            LlmClient::Anthropic(client) => {
+                let messages = vec![anthropic::Message {
+                    role: anthropic::Role::User,
+                    content: vec![anthropic::RequestContent::Text {
+                        text: prompt.to_string(),
+                        cache_control: None,
+                    }],
+                }];
+                let response = client.generate(model, max_tokens, messages, None).await?;
+                Ok(response.map(|r| {
+                    r.content
+                        .iter()
+                        .filter_map(|c| match c {
+                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+            LlmClient::OpenAi(client) => {
+                let messages = vec![open_ai::RequestMessage::User {
+                    content: open_ai::MessageContent::Plain(prompt.to_string()),
+                }];
+                let response = client.generate(model, max_tokens, messages, None).await?;
+                Ok(response.map(|r| {
+                    r.choices
+                        .into_iter()
+                        .filter_map(|choice| match choice.message {
+                            open_ai::RequestMessage::Assistant { content, .. } => {
+                                content.map(|c| match c {
+                                    open_ai::MessageContent::Plain(text) => text,
+                                    open_ai::MessageContent::Multipart(parts) => parts
+                                        .into_iter()
+                                        .filter_map(|p| match p {
+                                            open_ai::MessagePart::Text { text } => Some(text),
+                                            _ => None,
+                                        })
+                                        .collect::<Vec<_>>()
+                                        .join(""),
+                                })
+                            }
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+        }
+    }
+
+    /// Sync pending batches - upload queued requests and download completed results.
+    pub async fn sync_batches(&self) -> Result<()> {
+        match self {
+            LlmClient::Anthropic(client) => client.sync_batches().await,
+            LlmClient::OpenAi(client) => client.sync_batches().await,
+        }
+    }
+}
+
+/// Get the model name for a given backend.
+pub fn model_for_backend(backend: BatchProvider) -> &'static str {
+    match backend {
+        BatchProvider::Anthropic => "claude-sonnet-4-5",
+        BatchProvider::Openai => "gpt-5.2",
+    }
+}
diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index 6790491c69ae2888ca78bbd05db76ccacb92f974..bbc44aa3153460e9d4bd097c7126450ad88f68a0 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -5,6 +5,7 @@ mod filter_languages;
 mod format_prompt;
 mod git;
 mod headless;
+mod llm_client;
 mod load_project;
 mod metrics;
 mod openai_client;
@@ -291,6 +292,7 @@ enum PredictionProvider {
     Zeta2(ZetaVersion),
     Teacher(TeacherBackend),
     TeacherNonBatching(TeacherBackend),
+    RepairedTeacher(TeacherBackend),
     Repair,
 }
 
@@ -311,6 +313,9 @@ impl std::fmt::Display for PredictionProvider {
             PredictionProvider::TeacherNonBatching(backend) => {
                 write!(f, "teacher-non-batching:{backend}")
             }
+            PredictionProvider::RepairedTeacher(backend) => {
+                write!(f, "repaired-teacher:{backend}")
+            }
             PredictionProvider::Repair => write!(f, "repair"),
         }
     }
@@ -345,10 +350,17 @@ impl std::str::FromStr for PredictionProvider {
                     .unwrap_or(TeacherBackend::Sonnet45);
                 Ok(PredictionProvider::TeacherNonBatching(backend))
            }
+            "repaired-teacher" | "repaired_teacher" | "repairedteacher" => {
+                let backend = arg
+                    .map(|a| a.parse())
+                    .transpose()?
+                    .unwrap_or(TeacherBackend::Sonnet45);
+                Ok(PredictionProvider::RepairedTeacher(backend))
+            }
             "repair" => Ok(PredictionProvider::Repair),
             _ => {
                 anyhow::bail!(
-                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
+                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repaired-teacher, repair\n\
                     For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                     For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                     Available zeta versions:\n{}",
diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs
index 19c2591b4fe3a1fdede82269da37af170ea4d2d7..9567a6fca0b865f3a039a28c8ecedf7181831aa7 100644
--- a/crates/edit_prediction_cli/src/predict.rs
+++ b/crates/edit_prediction_cli/src/predict.rs
@@ -1,13 +1,16 @@
 use crate::{
-    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
+    BatchProvider, FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
     anthropic_client::AnthropicClient,
     example::{Example, ExamplePrediction, ExamplePrompt},
     format_prompt::{TeacherPrompt, run_format_prompt},
     headless::EpAppState,
+    llm_client::{LlmClient, model_for_backend},
     load_project::run_load_project,
     openai_client::OpenAiClient,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
     progress::{ExampleProgress, InfoStyle, Step},
+    qa,
+    repair::{build_repair_prompt, needs_repair, parse_repair_response},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
@@ -72,6 +75,21 @@ pub async fn run_prediction(
         return predict_teacher(example, backend, batched, repetition_count).await;
     }
 
+    if let PredictionProvider::RepairedTeacher(backend) = provider {
+        let _step_progress = example_progress.start(Step::Predict);
+
+        run_format_prompt(
+            example,
+            &FormatPromptArgs { provider },
+            app_state.clone(),
+            example_progress,
+            cx,
+        )
+        .await?;
+
+        return predict_repaired_teacher(example, backend, repetition_count).await;
+    }
+
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 
     let step_progress = example_progress.start(Step::Predict);
@@ -110,6 +128,7 @@
         PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
         PredictionProvider::Teacher(..)
         | PredictionProvider::TeacherNonBatching(..)
+        | PredictionProvider::RepairedTeacher(..)
         | PredictionProvider::Repair => {
             unreachable!()
         }
@@ -407,6 +426,83 @@ async fn predict_openai(
     Ok(())
 }
 
+/// Default confidence threshold for repair
+const DEFAULT_REPAIR_CONFIDENCE_THRESHOLD: u8 = 3;
+
+/// Predict using the teacher model, then run QA evaluation, and optionally repair
+/// if QA indicates issues (reverts_edits=true or low confidence).
+///
+/// This is a non-batched flow that processes each step synchronously.
+async fn predict_repaired_teacher(
+    example: &mut Example,
+    backend: TeacherBackend,
+    repetition_count: usize,
+) -> anyhow::Result<()> {
+    // Step 1: Run teacher prediction (non-batched for immediate results)
+    predict_teacher(example, backend, false, repetition_count).await?;
+
+    // Only proceed with QA/repair for the first prediction
+    let Some(prediction) = example.predictions.first() else {
+        return Ok(());
+    };
+
+    // Skip QA if no actual patch was generated
+    if prediction.actual_patch.is_none() {
+        return Ok(());
+    }
+
+    // Step 2: Run QA evaluation
+    let batch_provider = match backend {
+        TeacherBackend::Sonnet45 => BatchProvider::Anthropic,
+        TeacherBackend::Gpt52 => BatchProvider::Openai,
+    };
+    let qa_client = LlmClient::new(batch_provider, false)?;
+    let qa_model = model_for_backend(batch_provider);
+
+    let qa_result = if let Some(qa_prompt) = qa::build_prompt(example) {
+        match qa_client.generate(qa_model, 1024, &qa_prompt).await? {
+            Some(response_text) => Some(qa::parse_response(&response_text)),
+            None => None,
+        }
+    } else {
+        None
+    };
+
+    // Store QA result
+    example.qa = vec![qa_result.clone()];
+
+    // Step 3: Check if repair is needed and run repair if so
+    if needs_repair(example, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD) {
+        let repair_client = LlmClient::new(batch_provider, false)?;
+
+        if let Some(repair_prompt) = build_repair_prompt(example) {
+            if let Some(response_text) = repair_client
+                .generate(qa_model, 16384, &repair_prompt)
+                .await?
+            {
+                match parse_repair_response(example, &response_text) {
+                    Ok(mut repaired_prediction) => {
+                        // Mark the prediction as coming from repaired-teacher
+                        repaired_prediction.provider = PredictionProvider::RepairedTeacher(backend);
+                        example.predictions.push(repaired_prediction);
+                    }
+                    Err(e) => {
+                        // Add error prediction if parsing failed
+                        example.predictions.push(ExamplePrediction {
+                            actual_patch: None,
+                            actual_output: response_text,
+                            error: Some(format!("Failed to parse repair response: {}", e)),
+                            provider: PredictionProvider::RepairedTeacher(backend),
+                        });
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
 pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
     match provider {
         Some(PredictionProvider::Teacher(backend)) => match backend {
diff --git a/crates/edit_prediction_cli/src/prompts/repair.md b/crates/edit_prediction_cli/src/prompts/repair.md
index 3fb32cc5f5abdbf3e03a8c08f125096da498d7da..df293c4c6fade8fcee1bac83ce15f4c0538f0726 100644
--- a/crates/edit_prediction_cli/src/prompts/repair.md
+++ b/crates/edit_prediction_cli/src/prompts/repair.md
@@ -18,6 +18,7 @@ A previous model generated a prediction that was judged to have issues. Your job
 - Keep existing formatting unless it's absolutely necessary
 - Don't write a lot of code if you're not sure what to do
 - Do not delete or remove text that was just added in the edit history. If a recent edit introduces incomplete or incorrect code, finish or fix it in place, or simply do nothing rather than removing it. Only remove a recent edit if the history explicitly shows the user undoing it themselves.
+- When uncertain, predict only the minimal, high-confidence portion of the edit. Prefer a small, correct prediction over a large, speculative one
 
 # Input Format
 
@@ -34,9 +35,9 @@ You will be provided with:
 # Output Format
 
 - Briefly explain what was wrong with the previous prediction and how you'll improve it.
-- Output the entire editable region, applying the edits that you predict the user will make next.
-- If you're unsure about some portion of the next edit, you may still predict the surrounding code (such as a function definition, `for` loop, etc) and place the `<|user_cursor|>` within it for the user to fill in.
+- Output a markdown codeblock containing **only** the editable region with your predicted edits applied. The codeblock must start with `<|editable_region_start|>` and end with `<|editable_region_end|>`. Do not include any content before or after these tags.
 - Wrap the edited code in a codeblock with exactly five backticks.
+- If you're unsure about some portion of the next edit, you may still predict the surrounding code (such as a function definition, `for` loop, etc) and place the `<|user_cursor|>` within it for the user to fill in.
 
 # 1. User Edits History
 
@@ -68,4 +69,4 @@ The previous model generated the following edit (in word-diff format):
 
 # Your Improved Prediction
 
-Based on the feedback above, generate an improved prediction. Address the issues identified in the quality assessment.
+Based on the feedback above, generate an improved prediction. Address the issues identified in the quality assessment. Prefer a small, correct prediction over a large, speculative one.
diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs
index 28a592c2b875303d59087e3fe5e0e7d176ee74c2..c171684a36df0d520fc6c42b4f98c55b1d107171 100644
--- a/crates/edit_prediction_cli/src/qa.rs
+++ b/crates/edit_prediction_cli/src/qa.rs
@@ -4,11 +4,9 @@
 //! Caching is handled by the underlying client.
 
 use crate::BatchProvider;
-use crate::anthropic_client::AnthropicClient;
 use crate::example::Example;
 use crate::format_prompt::extract_cursor_excerpt_from_example;
-use crate::openai_client::OpenAiClient;
-use crate::paths::LLM_CACHE_DB;
+use crate::llm_client::{LlmClient, model_for_backend};
 use crate::word_diff::unified_to_word_diff;
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
@@ -33,13 +31,6 @@ pub struct QaArgs {
     pub backend: BatchProvider,
 }
 
-fn model_for_backend(backend: BatchProvider) -> &'static str {
-    match backend {
-        BatchProvider::Anthropic => "claude-sonnet-4-5",
-        BatchProvider::Openai => "gpt-5.2",
-    }
-}
-
 /// Result of QA evaluation for a single prediction.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct QaResult {
@@ -120,7 +111,7 @@ fn extract_codeblock(response: &str) -> Option<String> {
 }
 
 /// Parse the LLM response into a QaResult.
-fn parse_response(response_text: &str) -> QaResult {
+pub(crate) fn parse_response(response_text: &str) -> QaResult {
     let codeblock = extract_codeblock(response_text);
 
     // Try parsing codeblock first, then fall back to raw response
@@ -156,73 +147,6 @@ fn parse_response(response_text: &str) -> QaResult {
     }
 }
 
-enum QaClient {
-    Anthropic(AnthropicClient),
-    OpenAi(OpenAiClient),
-}
-
-impl QaClient {
-    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
-        match self {
-            QaClient::Anthropic(client) => {
-                let messages = vec![anthropic::Message {
-                    role: anthropic::Role::User,
-                    content: vec![anthropic::RequestContent::Text {
-                        text: prompt.to_string(),
-                        cache_control: None,
-                    }],
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.content
-                        .iter()
-                        .filter_map(|c| match c {
-                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-            QaClient::OpenAi(client) => {
-                let messages = vec![open_ai::RequestMessage::User {
-                    content: open_ai::MessageContent::Plain(prompt.to_string()),
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.choices
-                        .into_iter()
-                        .filter_map(|choice| match choice.message {
-                            open_ai::RequestMessage::Assistant { content, .. } => {
-                                content.map(|c| match c {
-                                    open_ai::MessageContent::Plain(text) => text,
-                                    open_ai::MessageContent::Multipart(parts) => parts
-                                        .into_iter()
-                                        .filter_map(|p| match p {
-                                            open_ai::MessagePart::Text { text } => Some(text),
-                                            _ => None,
-                                        })
-                                        .collect::<Vec<_>>()
-                                        .join(""),
-                                })
-                            }
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-        }
-    }
-
-    async fn sync_batches(&self) -> Result<()> {
-        match self {
-            QaClient::Anthropic(client) => client.sync_batches().await,
-            QaClient::OpenAi(client) => client.sync_batches().await,
-        }
-    }
-}
-
 /// Run the QA evaluation on a set of examples.
 pub async fn run_qa(
     examples: &mut [Example],
@@ -230,22 +154,7 @@ pub async fn run_qa(
     output_path: Option<&PathBuf>,
 ) -> Result<()> {
     let model = model_for_backend(args.backend);
-    let client = match args.backend {
-        BatchProvider::Anthropic => {
-            if args.no_batch {
-                QaClient::Anthropic(AnthropicClient::plain()?)
-            } else {
-                QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-        BatchProvider::Openai => {
-            if args.no_batch {
-                QaClient::OpenAi(OpenAiClient::plain()?)
-            } else {
-                QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-    };
+    let client = LlmClient::new(args.backend, !args.no_batch)?;
 
     eprintln!(
         "Using model: {}, backend: {:?}, batching: {}",
diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs
index e78420d2d5f4bface31d1bd85e21165b38226f76..134ffa2ee812c75a72c00923000e203cd2bbabf4 100644
--- a/crates/edit_prediction_cli/src/repair.rs
+++ b/crates/edit_prediction_cli/src/repair.rs
@@ -6,11 +6,9 @@
 
 use crate::BatchProvider;
 use crate::PredictionProvider;
-use crate::anthropic_client::AnthropicClient;
 use crate::example::{Example, ExamplePrediction};
 use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
-use crate::openai_client::OpenAiClient;
-use crate::paths::LLM_CACHE_DB;
+use crate::llm_client::{LlmClient, model_for_backend};
 use crate::word_diff::unified_to_word_diff;
 use anyhow::Result;
 use std::io::{BufWriter, Write};
@@ -30,7 +28,7 @@ pub struct RepairArgs {
     pub wait: bool,
 
     /// Confidence threshold: repair predictions with confidence <= this value (1-5)
-    #[clap(long, default_value = "2")]
+    #[clap(long, default_value = "3")]
     pub confidence_threshold: u8,
 
     /// Which LLM provider to use (anthropic or openai)
@@ -38,13 +36,6 @@ pub struct RepairArgs {
     pub backend: BatchProvider,
 }
 
-fn model_for_backend(backend: BatchProvider) -> &'static str {
-    match backend {
-        BatchProvider::Anthropic => "claude-sonnet-4-5",
-        BatchProvider::Openai => "gpt-5.2",
-    }
-}
-
 /// Build the repair prompt for an example that needs improvement.
 ///
 /// Returns None if the example doesn't have the required data (predictions, qa, prompt_inputs).
@@ -125,7 +116,10 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
 }
 
 /// Parse the repair response into a prediction.
-fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
+pub(crate) fn parse_repair_response(
+    example: &Example,
+    response_text: &str,
+) -> Result<ExamplePrediction> {
     let actual_patch = TeacherPrompt::parse(example, response_text)?;
 
     Ok(ExamplePrediction {
@@ -136,73 +130,6 @@ fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
     })
 }
 
-enum RepairClient {
-    Anthropic(AnthropicClient),
-    OpenAi(OpenAiClient),
-}
-
-impl RepairClient {
-    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
-        match self {
-            RepairClient::Anthropic(client) => {
-                let messages = vec![anthropic::Message {
-                    role: anthropic::Role::User,
-                    content: vec![anthropic::RequestContent::Text {
-                        text: prompt.to_string(),
-                        cache_control: None,
-                    }],
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.content
-                        .iter()
-                        .filter_map(|c| match c {
-                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-            RepairClient::OpenAi(client) => {
-                let messages = vec![open_ai::RequestMessage::User {
-                    content: open_ai::MessageContent::Plain(prompt.to_string()),
-                }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
-                Ok(response.map(|r| {
-                    r.choices
-                        .into_iter()
-                        .filter_map(|choice| match choice.message {
-                            open_ai::RequestMessage::Assistant { content, .. } => {
-                                content.map(|c| match c {
-                                    open_ai::MessageContent::Plain(text) => text,
-                                    open_ai::MessageContent::Multipart(parts) => parts
-                                        .into_iter()
-                                        .filter_map(|p| match p {
-                                            open_ai::MessagePart::Text { text } => Some(text),
-                                            _ => None,
-                                        })
-                                        .collect::<Vec<_>>()
-                                        .join(""),
-                                })
-                            }
-                            _ => None,
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }))
-            }
-        }
-    }
-
-    async fn sync_batches(&self) -> Result<()> {
-        match self {
-            RepairClient::Anthropic(client) => client.sync_batches().await,
-            RepairClient::OpenAi(client) => client.sync_batches().await,
-        }
-    }
-}
-
 /// Run the repair process on a set of examples.
 pub async fn run_repair(
     examples: &mut [Example],
@@ -210,22 +137,7 @@ pub async fn run_repair(
     output_path: Option<&PathBuf>,
 ) -> Result<()> {
     let model = model_for_backend(args.backend);
-    let client = match args.backend {
-        BatchProvider::Anthropic => {
-            if args.no_batch {
-                RepairClient::Anthropic(AnthropicClient::plain()?)
-            } else {
-                RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-        BatchProvider::Openai => {
-            if args.no_batch {
-                RepairClient::OpenAi(OpenAiClient::plain()?)
-            } else {
-                RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
-            }
-        }
-    };
+    let client = LlmClient::new(args.backend, !args.no_batch)?;
 
     eprintln!(
         "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}",
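
[Editor's note, not part of the patch: the sketch below shows how a caller inside edit_prediction_cli might drive the unified client introduced in llm_client.rs. It only uses items defined in this diff (LlmClient::new, generate, sync_batches, model_for_backend); the demo_call helper itself is hypothetical, and the behavior comments are assumptions based on the doc comments above.]

    // Sketch only; assumes it lives somewhere inside the edit_prediction_cli crate.
    use crate::BatchProvider;
    use crate::llm_client::{LlmClient, model_for_backend};
    use anyhow::Result;

    // Hypothetical helper, not part of the patch.
    async fn demo_call(backend: BatchProvider, batched: bool, prompt: &str) -> Result<()> {
        let client = LlmClient::new(backend, batched)?;
        let model = model_for_backend(backend);

        // Non-batched clients answer immediately with Some(text); batched clients
        // queue the request and return None until the batch round-trips.
        match client.generate(model, 1024, prompt).await? {
            Some(text) => println!("response: {text}"),
            None => {
                // Upload queued requests and download any finished batches; a later
                // generate call is then expected to pick up the cached result.
                client.sync_batches().await?;
            }
        }
        Ok(())
    }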
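
[Editor's note, not part of the patch: for reference, these are the provider spellings the new FromStr arm in main.rs accepts. The parse_examples helper is illustrative only, and the backend strings follow the teacher:sonnet45 / teacher:gpt52 convention quoted in the error message, so treat the exact behavior as an assumption.]

    // Sketch only; assumes it lives inside the edit_prediction_cli crate.
    use crate::{PredictionProvider, TeacherBackend};

    // Hypothetical helper, not part of the patch.
    fn parse_examples() -> anyhow::Result<()> {
        // Bare form: the backend defaults to Sonnet45.
        let p: PredictionProvider = "repaired-teacher".parse()?;
        assert!(matches!(p, PredictionProvider::RepairedTeacher(TeacherBackend::Sonnet45)));

        // Explicit backend after the colon; "repaired_teacher" and "repairedteacher"
        // are accepted as alternate spellings of the provider name.
        let p: PredictionProvider = "repaired-teacher:gpt52".parse()?;
        assert!(matches!(p, PredictionProvider::RepairedTeacher(TeacherBackend::Gpt52)));
        Ok(())
    }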