ep: Use rejected_output for DPO training + OpenAI support (#47697)

Oleksiy Syvokon and Zed Zippy created

Release Notes:

- N/A

---------

Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com>

Change summary

Cargo.lock                                      |   2 
crates/edit_prediction_cli/Cargo.toml           |   1 
crates/edit_prediction_cli/src/distill.rs       |  20 
crates/edit_prediction_cli/src/example.rs       |   1 
crates/edit_prediction_cli/src/format_prompt.rs |  27 
crates/edit_prediction_cli/src/main.rs          | 115 ++
crates/edit_prediction_cli/src/openai_client.rs | 664 +++++++++++++++++++
crates/edit_prediction_cli/src/predict.rs       | 131 +++
crates/edit_prediction_cli/src/qa.rs            | 182 +++--
crates/edit_prediction_cli/src/repair.rs        | 175 ++--
crates/open_ai/Cargo.toml                       |   1 
crates/open_ai/src/batches.rs                   | 332 +++++++++
crates/open_ai/src/open_ai.rs                   |  47 +
13 files changed, 1,497 insertions(+), 201 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5338,6 +5338,7 @@ dependencies = [
  "libc",
  "log",
  "node_runtime",
+ "open_ai",
  "paths",
  "pretty_assertions",
  "project",
@@ -11140,6 +11141,7 @@ dependencies = [
  "futures 0.3.31",
  "http_client",
  "log",
+ "rand 0.9.2",
  "schemars",
  "serde",
  "serde_json",

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -37,6 +37,7 @@ languages = { workspace = true, features = ["load-grammars"] }
 libc.workspace = true
 log.workspace = true
 node_runtime.workspace = true
+open_ai.workspace = true
 
 paths.workspace = true
 project.workspace = true

crates/edit_prediction_cli/src/distill.rs 🔗

@@ -1,17 +1,27 @@
 use anyhow::Result;
-use std::mem;
 
-use crate::example::Example;
+use crate::{PredictionProvider, example::Example};
 
 pub async fn run_distill(example: &mut Example) -> Result<()> {
-    let predictions = mem::take(&mut example.predictions)
+    let has_repair = example
+        .predictions
+        .iter()
+        .find(|p| p.provider == PredictionProvider::Repair);
+    let predictions = if let Some(has_repair) = has_repair {
+        vec![has_repair]
+    } else {
+        example.predictions.iter().collect()
+    };
+
+    let expected_patches = predictions
         .into_iter()
-        .filter_map(|p| p.actual_patch)
+        .filter_map(|p| p.actual_patch.clone())
         .collect();
 
-    example.spec.expected_patches = predictions;
+    example.spec.expected_patches = expected_patches;
     example.prompt = None;
     example.predictions = Vec::new();
     example.score = Vec::new();
+    example.qa = Vec::new();
     Ok(())
 }

crates/edit_prediction_cli/src/example.rs 🔗

@@ -73,6 +73,7 @@ pub struct ExamplePromptInputs {
 pub struct ExamplePrompt {
     pub input: String,
     pub expected_output: String,
+    pub rejected_output: Option<String>, // For DPO
     pub provider: PredictionProvider,
 }
 

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -48,27 +48,31 @@ pub async fn run_format_prompt(
     let snapshot = cx.background_spawn(snapshot_fut).await;
 
     match args.provider {
-        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
+        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
             step_progress.set_substatus("formatting teacher prompt");
 
             let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
                 cursor_point,
                 &snapshot,
-                edit_prediction::zeta2::max_editable_tokens(version),
+                edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()),
                 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
             );
             let editable_range = editable_range.to_offset(&snapshot);
             let context_range = context_range.to_offset(&snapshot);
 
             let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
+            let expected_output = example
+                .spec
+                .expected_patches
+                .first()
+                .cloned()
+                .unwrap_or_default();
+            let rejected_output = example.spec.rejected_patch.clone();
+
             example.prompt = Some(ExamplePrompt {
                 input: prompt,
-                expected_output: example
-                    .spec
-                    .expected_patches
-                    .first()
-                    .cloned()
-                    .unwrap_or_default(),
+                expected_output,
+                rejected_output,
                 provider: args.provider,
             });
         }
@@ -107,9 +111,16 @@ pub async fn run_format_prompt(
                     .clone(),
                 version,
             )?;
+            let rejected_output = example
+                .spec
+                .rejected_patch
+                .as_ref()
+                .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok());
+
             example.prompt = Some(ExamplePrompt {
                 input: prompt,
                 expected_output,
+                rejected_output,
                 provider: args.provider,
             });
         }

crates/edit_prediction_cli/src/main.rs 🔗

@@ -7,6 +7,7 @@ mod git;
 mod headless;
 mod load_project;
 mod metrics;
+mod openai_client;
 mod parse_output;
 mod paths;
 mod predict;
@@ -241,14 +242,51 @@ struct EvalArgs {
     summary_json: Option<PathBuf>,
 }
 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum TeacherBackend {
+    Sonnet45,
+    Gpt52,
+}
+
+impl std::fmt::Display for TeacherBackend {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
+            TeacherBackend::Gpt52 => write!(f, "gpt52"),
+        }
+    }
+}
+
+impl std::str::FromStr for TeacherBackend {
+    type Err = anyhow::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
+            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
+            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
+            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
+        }
+    }
+}
+
+impl TeacherBackend {
+    pub fn model_name(&self) -> &'static str {
+        match self {
+            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
+            TeacherBackend::Gpt52 => "gpt-5.2",
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 enum PredictionProvider {
     Sweep,
     Mercury,
     Zeta1,
     Zeta2(ZetaVersion),
-    Teacher(ZetaVersion),
-    TeacherNonBatching(ZetaVersion),
+    Teacher(TeacherBackend),
+    TeacherNonBatching(TeacherBackend),
     Repair,
 }
 
@@ -265,9 +303,9 @@ impl std::fmt::Display for PredictionProvider {
             PredictionProvider::Mercury => write!(f, "mercury"),
             PredictionProvider::Zeta1 => write!(f, "zeta1"),
             PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
-            PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
-            PredictionProvider::TeacherNonBatching(version) => {
-                write!(f, "teacher-non-batching:{version}")
+            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
+            PredictionProvider::TeacherNonBatching(backend) => {
+                write!(f, "teacher-non-batching:{backend}")
             }
             PredictionProvider::Repair => write!(f, "repair"),
         }
@@ -277,28 +315,38 @@ impl std::fmt::Display for PredictionProvider {
 impl std::str::FromStr for PredictionProvider {
     type Err = anyhow::Error;
 
-    fn from_str(mut s: &str) -> Result<Self, Self::Err> {
-        let mut version = ZetaVersion::default();
-        if let Some((first, second)) = s.split_once(':') {
-            version = ZetaVersion::parse(second)?;
-            s = first;
-        }
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 
-        let s_lower = s.to_lowercase();
-        match s_lower.as_str() {
+        let provider_lower = provider.to_lowercase();
+        match provider_lower.as_str() {
             "sweep" => Ok(PredictionProvider::Sweep),
             "mercury" => Ok(PredictionProvider::Mercury),
             "zeta1" => Ok(PredictionProvider::Zeta1),
-            "zeta2" => Ok(PredictionProvider::Zeta2(version)),
-            "teacher" => Ok(PredictionProvider::Teacher(version)),
+            "zeta2" => {
+                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
+                Ok(PredictionProvider::Zeta2(version))
+            }
+            "teacher" => {
+                let backend = arg
+                    .map(|a| a.parse())
+                    .transpose()?
+                    .unwrap_or(TeacherBackend::Sonnet45);
+                Ok(PredictionProvider::Teacher(backend))
+            }
             "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
-                Ok(PredictionProvider::TeacherNonBatching(version))
+                let backend = arg
+                    .map(|a| a.parse())
+                    .transpose()?
+                    .unwrap_or(TeacherBackend::Sonnet45);
+                Ok(PredictionProvider::TeacherNonBatching(backend))
             }
             "repair" => Ok(PredictionProvider::Repair),
             _ => {
                 anyhow::bail!(
-                    "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching, repair\n\
+                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                  For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
+                 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                  Available zeta versions:\n{}",
                     ZetaVersion::options_as_string()
                 )
@@ -347,9 +395,18 @@ struct SynthesizeArgs {
 
 #[derive(Debug, Args, Clone)]
 struct ImportBatchArgs {
-    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
+    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
     #[clap(long, required = true, num_args = 1..)]
     batch_ids: Vec<String>,
+    /// Which provider's batches to import (anthropic or openai)
+    #[clap(long, default_value = "anthropic")]
+    provider: BatchProvider,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum BatchProvider {
+    Anthropic,
+    Openai,
 }
 
 impl EpArgs {
@@ -537,11 +594,23 @@ fn main() {
     match &command {
         Command::ImportBatch(import_args) => {
             smol::block_on(async {
-                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
-                    .expect("Failed to create Anthropic client");
-                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
-                    eprintln!("Error importing batches: {:?}", e);
-                    std::process::exit(1);
+                match import_args.provider {
+                    BatchProvider::Anthropic => {
+                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
+                            .expect("Failed to create Anthropic client");
+                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
+                            eprintln!("Error importing Anthropic batches: {:?}", e);
+                            std::process::exit(1);
+                        }
+                    }
+                    BatchProvider::Openai => {
+                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
+                            .expect("Failed to create OpenAI client");
+                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
+                            eprintln!("Error importing OpenAI batches: {:?}", e);
+                            std::process::exit(1);
+                        }
+                    }
                 }
                 println!(
                     "Successfully imported {} batch(es)",

crates/edit_prediction_cli/src/openai_client.rs 🔗

@@ -0,0 +1,664 @@
+use anyhow::Result;
+use http_client::HttpClient;
+use indoc::indoc;
+use open_ai::{
+    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
+    Response as OpenAiResponse, batches, non_streaming_completion,
+};
+use reqwest_client::ReqwestClient;
+use sqlez::bindable::Bind;
+use sqlez::bindable::StaticColumnCount;
+use sqlez_macros::sql;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::path::Path;
+use std::sync::{Arc, Mutex};
+
+pub struct PlainOpenAiClient {
+    pub http_client: Arc<dyn HttpClient>,
+    pub api_key: String,
+}
+
+impl PlainOpenAiClient {
+    pub fn new() -> Result<Self> {
+        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+        let api_key = std::env::var("OPENAI_API_KEY")
+            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
+        Ok(Self {
+            http_client,
+            api_key,
+        })
+    }
+
+    pub async fn generate(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: Vec<RequestMessage>,
+    ) -> Result<OpenAiResponse> {
+        let request = OpenAiRequest {
+            model: model.to_string(),
+            messages,
+            stream: false,
+            max_completion_tokens: Some(max_tokens),
+            stop: Vec::new(),
+            temperature: None,
+            tool_choice: None,
+            parallel_tool_calls: None,
+            tools: Vec::new(),
+            prompt_cache_key: None,
+            reasoning_effort: None,
+        };
+
+        let response = non_streaming_completion(
+            self.http_client.as_ref(),
+            OPEN_AI_API_URL,
+            &self.api_key,
+            request,
+        )
+        .await
+        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+        Ok(response)
+    }
+}
+
+pub struct BatchingOpenAiClient {
+    connection: Mutex<sqlez::connection::Connection>,
+    http_client: Arc<dyn HttpClient>,
+    api_key: String,
+}
+
+struct CacheRow {
+    request_hash: String,
+    request: Option<String>,
+    response: Option<String>,
+    batch_id: Option<String>,
+}
+
+impl StaticColumnCount for CacheRow {
+    fn column_count() -> usize {
+        4
+    }
+}
+
+impl Bind for CacheRow {
+    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
+        let next_index = statement.bind(&self.request_hash, start_index)?;
+        let next_index = statement.bind(&self.request, next_index)?;
+        let next_index = statement.bind(&self.response, next_index)?;
+        let next_index = statement.bind(&self.batch_id, next_index)?;
+        Ok(next_index)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableRequest {
+    model: String,
+    max_tokens: u64,
+    messages: Vec<SerializableMessage>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableMessage {
+    role: String,
+    content: String,
+}
+
+impl BatchingOpenAiClient {
+    fn new(cache_path: &Path) -> Result<Self> {
+        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+        let api_key = std::env::var("OPENAI_API_KEY")
+            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
+
+        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
+        let mut statement = sqlez::statement::Statement::prepare(
+            &connection,
+            indoc! {"
+                CREATE TABLE IF NOT EXISTS openai_cache (
+                    request_hash TEXT PRIMARY KEY,
+                    request TEXT,
+                    response TEXT,
+                    batch_id TEXT
+                );
+                "},
+        )?;
+        statement.exec()?;
+        drop(statement);
+
+        Ok(Self {
+            connection: Mutex::new(connection),
+            http_client,
+            api_key,
+        })
+    }
+
+    pub fn lookup(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: &[RequestMessage],
+    ) -> Result<Option<OpenAiResponse>> {
+        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let connection = self.connection.lock().unwrap();
+        let response: Vec<String> = connection.select_bound(
+            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
+        )?(request_hash_str.as_str())?;
+        Ok(response
+            .into_iter()
+            .next()
+            .and_then(|text| serde_json::from_str(&text).ok()))
+    }
+
+    pub fn mark_for_batch(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: &[RequestMessage],
+    ) -> Result<()> {
+        let request_hash = Self::request_hash(model, max_tokens, messages);
+
+        let serializable_messages: Vec<SerializableMessage> = messages
+            .iter()
+            .map(|msg| SerializableMessage {
+                role: message_role_to_string(msg),
+                content: message_content_to_string(msg),
+            })
+            .collect();
+
+        let serializable_request = SerializableRequest {
+            model: model.to_string(),
+            max_tokens,
+            messages: serializable_messages,
+        };
+
+        let request = Some(serde_json::to_string(&serializable_request)?);
+        let cache_row = CacheRow {
+            request_hash,
+            request,
+            response: None,
+            batch_id: None,
+        };
+        let connection = self.connection.lock().unwrap();
+        connection.exec_bound::<CacheRow>(sql!(
+            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
+            cache_row,
+        )
+    }
+
+    async fn generate(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: Vec<RequestMessage>,
+    ) -> Result<Option<OpenAiResponse>> {
+        let response = self.lookup(model, max_tokens, &messages)?;
+        if let Some(response) = response {
+            return Ok(Some(response));
+        }
+
+        self.mark_for_batch(model, max_tokens, &messages)?;
+
+        Ok(None)
+    }
+
+    async fn sync_batches(&self) -> Result<()> {
+        let _batch_ids = self.upload_pending_requests().await?;
+        self.download_finished_batches().await
+    }
+
+    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
+        for batch_id in batch_ids {
+            log::info!("Importing OpenAI batch {}", batch_id);
+
+            let batch_status = batches::retrieve_batch(
+                self.http_client.as_ref(),
+                OPEN_AI_API_URL,
+                &self.api_key,
+                batch_id,
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;
+
+            log::info!("Batch {} status: {}", batch_id, batch_status.status);
+
+            if batch_status.status != "completed" {
+                log::warn!(
+                    "Batch {} is not completed (status: {}), skipping",
+                    batch_id,
+                    batch_status.status
+                );
+                continue;
+            }
+
+            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
+                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
+            })?;
+
+            let results_content = batches::download_file(
+                self.http_client.as_ref(),
+                OPEN_AI_API_URL,
+                &self.api_key,
+                &output_file_id,
+            )
+            .await
+            .map_err(|e| {
+                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
+            })?;
+
+            let results = batches::parse_batch_output(&results_content)
+                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
+
+            let mut updates: Vec<(String, String, String)> = Vec::new();
+            let mut success_count = 0;
+            let mut error_count = 0;
+
+            for result in results {
+                let request_hash = result
+                    .custom_id
+                    .strip_prefix("req_hash_")
+                    .unwrap_or(&result.custom_id)
+                    .to_string();
+
+                if let Some(response_body) = result.response {
+                    if response_body.status_code == 200 {
+                        let response_json = serde_json::to_string(&response_body.body)?;
+                        updates.push((request_hash, response_json, batch_id.clone()));
+                        success_count += 1;
+                    } else {
+                        log::error!(
+                            "Batch request {} failed with status {}",
+                            request_hash,
+                            response_body.status_code
+                        );
+                        let error_json = serde_json::json!({
+                            "error": {
+                                "type": "http_error",
+                                "status_code": response_body.status_code
+                            }
+                        })
+                        .to_string();
+                        updates.push((request_hash, error_json, batch_id.clone()));
+                        error_count += 1;
+                    }
+                } else if let Some(error) = result.error {
+                    log::error!(
+                        "Batch request {} failed: {}: {}",
+                        request_hash,
+                        error.code,
+                        error.message
+                    );
+                    let error_json = serde_json::json!({
+                        "error": {
+                            "type": error.code,
+                            "message": error.message
+                        }
+                    })
+                    .to_string();
+                    updates.push((request_hash, error_json, batch_id.clone()));
+                    error_count += 1;
+                }
+            }
+
+            let connection = self.connection.lock().unwrap();
+            connection.with_savepoint("batch_import", || {
+                let q = sql!(
+                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
+                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
+                );
+                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
+                for (request_hash, response_json, batch_id) in &updates {
+                    exec((
+                        request_hash.as_str(),
+                        request_hash.as_str(),
+                        response_json.as_str(),
+                        batch_id.as_str(),
+                    ))?;
+                }
+                Ok(())
+            })?;
+
+            log::info!(
+                "Imported batch {}: {} successful, {} errors",
+                batch_id,
+                success_count,
+                error_count
+            );
+        }
+
+        Ok(())
+    }
+
+    async fn download_finished_batches(&self) -> Result<()> {
+        let batch_ids: Vec<String> = {
+            let connection = self.connection.lock().unwrap();
+            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
+            connection.select(q)?()?
+        };
+
+        for batch_id in &batch_ids {
+            let batch_status = batches::retrieve_batch(
+                self.http_client.as_ref(),
+                OPEN_AI_API_URL,
+                &self.api_key,
+                batch_id,
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+            log::info!("Batch {} status: {}", batch_id, batch_status.status);
+
+            if batch_status.status == "completed" {
+                let output_file_id = match batch_status.output_file_id {
+                    Some(id) => id,
+                    None => {
+                        log::warn!("Batch {} completed but has no output file", batch_id);
+                        continue;
+                    }
+                };
+
+                let results_content = batches::download_file(
+                    self.http_client.as_ref(),
+                    OPEN_AI_API_URL,
+                    &self.api_key,
+                    &output_file_id,
+                )
+                .await
+                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+                let results = batches::parse_batch_output(&results_content)
+                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
+
+                let mut updates: Vec<(String, String)> = Vec::new();
+                let mut success_count = 0;
+
+                for result in results {
+                    let request_hash = result
+                        .custom_id
+                        .strip_prefix("req_hash_")
+                        .unwrap_or(&result.custom_id)
+                        .to_string();
+
+                    if let Some(response_body) = result.response {
+                        if response_body.status_code == 200 {
+                            let response_json = serde_json::to_string(&response_body.body)?;
+                            updates.push((response_json, request_hash));
+                            success_count += 1;
+                        } else {
+                            log::error!(
+                                "Batch request {} failed with status {}",
+                                request_hash,
+                                response_body.status_code
+                            );
+                            let error_json = serde_json::json!({
+                                "error": {
+                                    "type": "http_error",
+                                    "status_code": response_body.status_code
+                                }
+                            })
+                            .to_string();
+                            updates.push((error_json, request_hash));
+                        }
+                    } else if let Some(error) = result.error {
+                        log::error!(
+                            "Batch request {} failed: {}: {}",
+                            request_hash,
+                            error.code,
+                            error.message
+                        );
+                        let error_json = serde_json::json!({
+                            "error": {
+                                "type": error.code,
+                                "message": error.message
+                            }
+                        })
+                        .to_string();
+                        updates.push((error_json, request_hash));
+                    }
+                }
+
+                let connection = self.connection.lock().unwrap();
+                connection.with_savepoint("batch_download", || {
+                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
+                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+                    for (response_json, request_hash) in &updates {
+                        exec((response_json.as_str(), request_hash.as_str()))?;
+                    }
+                    Ok(())
+                })?;
+                log::info!("Downloaded {} successful requests", success_count);
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
+        const BATCH_CHUNK_SIZE: i32 = 16_000;
+        let mut all_batch_ids = Vec::new();
+        let mut total_uploaded = 0;
+
+        loop {
+            let rows: Vec<(String, String)> = {
+                let connection = self.connection.lock().unwrap();
+                let q = sql!(
+                    SELECT request_hash, request FROM openai_cache
+                    WHERE batch_id IS NULL AND response IS NULL
+                    LIMIT ?
+                );
+                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
+            };
+
+            if rows.is_empty() {
+                break;
+            }
+
+            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
+
+            let mut jsonl_content = String::new();
+            for (hash, request_str) in &rows {
+                let serializable_request: SerializableRequest =
+                    serde_json::from_str(request_str).unwrap();
+
+                let messages: Vec<RequestMessage> = serializable_request
+                    .messages
+                    .into_iter()
+                    .map(|msg| match msg.role.as_str() {
+                        "user" => RequestMessage::User {
+                            content: MessageContent::Plain(msg.content),
+                        },
+                        "assistant" => RequestMessage::Assistant {
+                            content: Some(MessageContent::Plain(msg.content)),
+                            tool_calls: Vec::new(),
+                        },
+                        "system" => RequestMessage::System {
+                            content: MessageContent::Plain(msg.content),
+                        },
+                        _ => RequestMessage::User {
+                            content: MessageContent::Plain(msg.content),
+                        },
+                    })
+                    .collect();
+
+                let request = OpenAiRequest {
+                    model: serializable_request.model,
+                    messages,
+                    stream: false,
+                    max_completion_tokens: Some(serializable_request.max_tokens),
+                    stop: Vec::new(),
+                    temperature: None,
+                    tool_choice: None,
+                    parallel_tool_calls: None,
+                    tools: Vec::new(),
+                    prompt_cache_key: None,
+                    reasoning_effort: None,
+                };
+
+                let custom_id = format!("req_hash_{}", hash);
+                let batch_item = batches::BatchRequestItem::new(custom_id, request);
+                let line = batch_item
+                    .to_jsonl_line()
+                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
+                jsonl_content.push_str(&line);
+                jsonl_content.push('\n');
+            }
+
+            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
+            let file_obj = batches::upload_batch_file(
+                self.http_client.as_ref(),
+                OPEN_AI_API_URL,
+                &self.api_key,
+                &filename,
+                jsonl_content.into_bytes(),
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;
+
+            let batch = batches::create_batch(
+                self.http_client.as_ref(),
+                OPEN_AI_API_URL,
+                &self.api_key,
+                batches::CreateBatchRequest::new(file_obj.id),
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;
+
+            {
+                let connection = self.connection.lock().unwrap();
+                connection.with_savepoint("batch_upload", || {
+                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
+                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+                    for hash in &request_hashes {
+                        exec((batch.id.as_str(), hash.as_str()))?;
+                    }
+                    Ok(())
+                })?;
+            }
+
+            let batch_len = rows.len();
+            total_uploaded += batch_len;
+            log::info!(
+                "Uploaded batch {} with {} requests ({} total)",
+                batch.id,
+                batch_len,
+                total_uploaded
+            );
+
+            all_batch_ids.push(batch.id);
+        }
+
+        if !all_batch_ids.is_empty() {
+            log::info!(
+                "Finished uploading {} batches with {} total requests",
+                all_batch_ids.len(),
+                total_uploaded
+            );
+        }
+
+        Ok(all_batch_ids)
+    }
+
+    fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
+        let mut hasher = std::hash::DefaultHasher::new();
+        "openai".hash(&mut hasher);
+        model.hash(&mut hasher);
+        max_tokens.hash(&mut hasher);
+        for msg in messages {
+            message_content_to_string(msg).hash(&mut hasher);
+        }
+        let request_hash = hasher.finish();
+        format!("{request_hash:016x}")
+    }
+}
+
+fn message_role_to_string(msg: &RequestMessage) -> String {
+    match msg {
+        RequestMessage::User { .. } => "user".to_string(),
+        RequestMessage::Assistant { .. } => "assistant".to_string(),
+        RequestMessage::System { .. } => "system".to_string(),
+        RequestMessage::Tool { .. } => "tool".to_string(),
+    }
+}
+
+fn message_content_to_string(msg: &RequestMessage) -> String {
+    match msg {
+        RequestMessage::User { content } => content_to_string(content),
+        RequestMessage::Assistant { content, .. } => {
+            content.as_ref().map(content_to_string).unwrap_or_default()
+        }
+        RequestMessage::System { content } => content_to_string(content),
+        RequestMessage::Tool { content, .. } => content_to_string(content),
+    }
+}
+
+fn content_to_string(content: &MessageContent) -> String {
+    match content {
+        MessageContent::Plain(text) => text.clone(),
+        MessageContent::Multipart(parts) => parts
+            .iter()
+            .filter_map(|part| match part {
+                open_ai::MessagePart::Text { text } => Some(text.clone()),
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n"),
+    }
+}
+
+pub enum OpenAiClient {
+    Plain(PlainOpenAiClient),
+    Batch(BatchingOpenAiClient),
+    #[allow(dead_code)]
+    Dummy,
+}
+
+impl OpenAiClient {
+    pub fn plain() -> Result<Self> {
+        Ok(Self::Plain(PlainOpenAiClient::new()?))
+    }
+
+    pub fn batch(cache_path: &Path) -> Result<Self> {
+        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
+    }
+
+    #[allow(dead_code)]
+    pub fn dummy() -> Self {
+        Self::Dummy
+    }
+
+    pub async fn generate(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: Vec<RequestMessage>,
+    ) -> Result<Option<OpenAiResponse>> {
+        match self {
+            OpenAiClient::Plain(plain_client) => plain_client
+                .generate(model, max_tokens, messages)
+                .await
+                .map(Some),
+            OpenAiClient::Batch(batching_client) => {
+                batching_client.generate(model, max_tokens, messages).await
+            }
+            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
+        }
+    }
+
+    pub async fn sync_batches(&self) -> Result<()> {
+        match self {
+            OpenAiClient::Plain(_) => Ok(()),
+            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
+            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
+        }
+    }
+
+    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
+        match self {
+            OpenAiClient::Plain(_) => {
+                anyhow::bail!("Import batches is only supported with batching client")
+            }
+            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
+            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
+        }
+    }
+}

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -1,10 +1,11 @@
 use crate::{
-    FormatPromptArgs, PredictArgs, PredictionProvider,
+    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
     anthropic_client::AnthropicClient,
     example::{Example, ExamplePrediction, ExamplePrompt},
     format_prompt::{TeacherPrompt, run_format_prompt},
     headless::EpAppState,
     load_project::run_load_project,
+    openai_client::OpenAiClient,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
     progress::{ExampleProgress, InfoStyle, Step},
     retrieve_context::run_context_retrieval,
@@ -20,9 +21,9 @@ use std::{
         atomic::{AtomicUsize, Ordering::SeqCst},
     },
 };
-use zeta_prompt::ZetaVersion;
 
 static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
+static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
 
 pub async fn run_prediction(
     example: &mut Example,
@@ -53,7 +54,7 @@ pub async fn run_prediction(
 
     run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
-    if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
+    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
         provider
     {
         let _step_progress = example_progress.start(Step::Predict);
@@ -68,7 +69,7 @@ pub async fn run_prediction(
         .await?;
 
         let batched = matches!(provider, PredictionProvider::Teacher(..));
-        return predict_anthropic(example, repetition_count, version, batched).await;
+        return predict_teacher(example, backend, batched).await;
     }
 
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -148,6 +149,7 @@ pub async fn run_prediction(
                                 updated_example.prompt.get_or_insert(ExamplePrompt {
                                     input: prompt,
                                     expected_output: String::new(),
+                                    rejected_output: None,
                                     provider,
                                 });
                             }
@@ -255,13 +257,23 @@ pub async fn run_prediction(
     Ok(())
 }
 
+async fn predict_teacher(
+    example: &mut Example,
+    backend: TeacherBackend,
+    batched: bool,
+) -> anyhow::Result<()> {
+    match backend {
+        TeacherBackend::Sonnet45 => predict_anthropic(example, backend, batched).await,
+        TeacherBackend::Gpt52 => predict_openai(example, backend, batched).await,
+    }
+}
+
 async fn predict_anthropic(
     example: &mut Example,
-    _repetition_count: usize,
-    version: ZetaVersion,
+    backend: TeacherBackend,
     batched: bool,
 ) -> anyhow::Result<()> {
-    let llm_model_name = "claude-sonnet-4-5";
+    let llm_model_name = backend.model_name();
     let max_tokens = 16384;
     let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
         let client = if batched {
@@ -300,16 +312,83 @@ async fn predict_anthropic(
         .collect::<Vec<String>>()
         .join("\n");
 
-    let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;
+    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
 
     let prediction = ExamplePrediction {
         actual_patch: Some(actual_patch),
         actual_output,
         error: None,
         provider: if batched {
-            PredictionProvider::Teacher(version)
+            PredictionProvider::Teacher(backend)
         } else {
-            PredictionProvider::TeacherNonBatching(version)
+            PredictionProvider::TeacherNonBatching(backend)
+        },
+    };
+
+    example.predictions.push(prediction);
+    Ok(())
+}
+
+async fn predict_openai(
+    example: &mut Example,
+    backend: TeacherBackend,
+    batched: bool,
+) -> anyhow::Result<()> {
+    let llm_model_name = backend.model_name();
+    let max_tokens = 16384;
+    let llm_client = OPENAI_CLIENT.get_or_init(|| {
+        let client = if batched {
+            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
+        } else {
+            OpenAiClient::plain()
+        };
+        client.expect("Failed to create OpenAI client")
+    });
+
+    let prompt = example.prompt.as_ref().context("Prompt is required")?;
+
+    let messages = vec![open_ai::RequestMessage::User {
+        content: open_ai::MessageContent::Plain(prompt.input.clone()),
+    }];
+
+    let Some(response) = llm_client
+        .generate(llm_model_name, max_tokens, messages)
+        .await?
+    else {
+        // Request stashed for batched processing
+        return Ok(());
+    };
+
+    let actual_output = response
+        .choices
+        .into_iter()
+        .filter_map(|choice| match choice.message {
+            open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
+                open_ai::MessageContent::Plain(text) => text,
+                open_ai::MessageContent::Multipart(parts) => parts
+                    .into_iter()
+                    .filter_map(|p| match p {
+                        open_ai::MessagePart::Text { text } => Some(text),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>()
+                    .join(""),
+            }),
+            _ => None,
+        })
+        .collect::<Vec<String>>()
+        .join("\n");
+
+    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+    let prediction = ExamplePrediction {
+        actual_patch: Some(actual_patch),
+        actual_output,
+        error: None,
+        provider: if batched {
+            PredictionProvider::Teacher(backend)
+        } else {
+            PredictionProvider::TeacherNonBatching(backend)
         },
     };
 
@@ -319,16 +398,28 @@ async fn predict_anthropic(
 
 pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
     match provider {
-        Some(PredictionProvider::Teacher(..)) => {
-            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
-                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
-                    .expect("Failed to create Anthropic client")
-            });
-            llm_client
-                .sync_batches()
-                .await
-                .context("Failed to sync batches")?;
-        }
+        Some(PredictionProvider::Teacher(backend)) => match backend {
+            TeacherBackend::Sonnet45 => {
+                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+                        .expect("Failed to create Anthropic client")
+                });
+                llm_client
+                    .sync_batches()
+                    .await
+                    .context("Failed to sync Anthropic batches")?;
+            }
+            TeacherBackend::Gpt52 => {
+                let llm_client = OPENAI_CLIENT.get_or_init(|| {
+                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
+                        .expect("Failed to create OpenAI client")
+                });
+                llm_client
+                    .sync_batches()
+                    .await
+                    .context("Failed to sync OpenAI batches")?;
+            }
+        },
         _ => (),
     };
     Ok(())

crates/edit_prediction_cli/src/qa.rs 🔗

@@ -1,22 +1,20 @@
 //! Quality assessment of predictions using LLM-as-a-judge.
 //!
-//! This module uses the Anthropic Batch API to evaluate prediction quality.
-//! Caching is handled by the underlying AnthropicClient.
+//! This module uses LLM Batch APIs to evaluate prediction quality.
+//! Caching is handled by the underlying client.
 
+use crate::BatchProvider;
 use crate::anthropic_client::AnthropicClient;
 use crate::example::Example;
 use crate::format_prompt::extract_cursor_excerpt_from_example;
+use crate::openai_client::OpenAiClient;
 use crate::paths::LLM_CACHE_DB;
 use crate::word_diff::unified_to_word_diff;
-use anthropic::{Message, RequestContent, Role};
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
 use std::io::{BufWriter, Write};
 use std::path::PathBuf;
 
-/// Model to use for QA evaluation.
-const MODEL: &str = "claude-sonnet-4-5";
-
 const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
 
 /// Arguments for the QA command.
@@ -29,6 +27,17 @@ pub struct QaArgs {
     /// Wait for batch to complete (polls every 30s)
     #[clap(long)]
     pub wait: bool,
+
+    /// Which LLM provider to use (anthropic or openai)
+    #[clap(long, default_value = "anthropic")]
+    pub backend: BatchProvider,
+}
+
+fn model_for_backend(backend: BatchProvider) -> &'static str {
+    match backend {
+        BatchProvider::Anthropic => "claude-sonnet-4-5",
+        BatchProvider::Openai => "gpt-5.2",
+    }
 }
 
 /// Result of QA evaluation for a single prediction.
@@ -147,19 +156,101 @@ fn parse_response(response_text: &str) -> QaResult {
     }
 }
 
+enum QaClient {
+    Anthropic(AnthropicClient),
+    OpenAi(OpenAiClient),
+}
+
+impl QaClient {
+    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
+        match self {
+            QaClient::Anthropic(client) => {
+                let messages = vec![anthropic::Message {
+                    role: anthropic::Role::User,
+                    content: vec![anthropic::RequestContent::Text {
+                        text: prompt.to_string(),
+                        cache_control: None,
+                    }],
+                }];
+                let response = client.generate(model, max_tokens, messages).await?;
+                Ok(response.map(|r| {
+                    r.content
+                        .iter()
+                        .filter_map(|c| match c {
+                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+            QaClient::OpenAi(client) => {
+                let messages = vec![open_ai::RequestMessage::User {
+                    content: open_ai::MessageContent::Plain(prompt.to_string()),
+                }];
+                let response = client.generate(model, max_tokens, messages).await?;
+                Ok(response.map(|r| {
+                    r.choices
+                        .into_iter()
+                        .filter_map(|choice| match choice.message {
+                            open_ai::RequestMessage::Assistant { content, .. } => {
+                                content.map(|c| match c {
+                                    open_ai::MessageContent::Plain(text) => text,
+                                    open_ai::MessageContent::Multipart(parts) => parts
+                                        .into_iter()
+                                        .filter_map(|p| match p {
+                                            open_ai::MessagePart::Text { text } => Some(text),
+                                            _ => None,
+                                        })
+                                        .collect::<Vec<_>>()
+                                        .join(""),
+                                })
+                            }
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+        }
+    }
+
+    async fn sync_batches(&self) -> Result<()> {
+        match self {
+            QaClient::Anthropic(client) => client.sync_batches().await,
+            QaClient::OpenAi(client) => client.sync_batches().await,
+        }
+    }
+}
+
 /// Run the QA evaluation on a set of examples.
 pub async fn run_qa(
     examples: &mut [Example],
     args: &QaArgs,
     output_path: Option<&PathBuf>,
 ) -> Result<()> {
-    let client = if args.no_batch {
-        AnthropicClient::plain()?
-    } else {
-        AnthropicClient::batch(&LLM_CACHE_DB)?
+    let model = model_for_backend(args.backend);
+    let client = match args.backend {
+        BatchProvider::Anthropic => {
+            if args.no_batch {
+                QaClient::Anthropic(AnthropicClient::plain()?)
+            } else {
+                QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
+            }
+        }
+        BatchProvider::Openai => {
+            if args.no_batch {
+                QaClient::OpenAi(OpenAiClient::plain()?)
+            } else {
+                QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
+            }
+        }
     };
 
-    eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);
+    eprintln!(
+        "Using model: {}, backend: {:?}, batching: {}",
+        model, args.backend, !args.no_batch
+    );
 
     // First pass: send requests (client handles caching internally)
     let mut prompts: Vec<(usize, String)> = Vec::new();
@@ -187,54 +278,16 @@ pub async fn run_qa(
         for (i, (idx, prompt)) in prompts.iter().enumerate() {
             eprint!("\rProcessing {}/{}", i + 1, prompts.len());
 
-            let messages = vec![Message {
-                role: Role::User,
-                content: vec![RequestContent::Text {
-                    text: prompt.clone(),
-                    cache_control: None,
-                }],
-            }];
-
-            let response = client.generate(MODEL, 1024, messages).await?;
-            let result = response.map(|r| {
-                let text = r
-                    .content
-                    .iter()
-                    .filter_map(|c| match c {
-                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                        _ => None,
-                    })
-                    .collect::<Vec<_>>()
-                    .join("");
-                parse_response(&text)
-            });
+            let response = client.generate(model, 1024, prompt).await?;
+            let result = response.map(|text| parse_response(&text));
             results.push((*idx, result));
         }
         eprintln!();
     } else {
         // Queue all for batching
         for (idx, prompt) in &prompts {
-            let messages = vec![Message {
-                role: Role::User,
-                content: vec![RequestContent::Text {
-                    text: prompt.clone(),
-                    cache_control: None,
-                }],
-            }];
-
-            let response = client.generate(MODEL, 1024, messages).await?;
-            let result = response.map(|r| {
-                let text = r
-                    .content
-                    .iter()
-                    .filter_map(|c| match c {
-                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                        _ => None,
-                    })
-                    .collect::<Vec<_>>()
-                    .join("");
-                parse_response(&text)
-            });
+            let response = client.generate(model, 1024, prompt).await?;
+            let result = response.map(|text| parse_response(&text));
             results.push((*idx, result));
         }
 
@@ -251,27 +304,8 @@ pub async fn run_qa(
                 let mut all_done = true;
                 for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
                     if results[result_idx].1.is_none() {
-                        let messages = vec![Message {
-                            role: Role::User,
-                            content: vec![RequestContent::Text {
-                                text: prompt.clone(),
-                                cache_control: None,
-                            }],
-                        }];
-
-                        let response = client.generate(MODEL, 1024, messages).await?;
-                        if let Some(r) = response {
-                            let text = r
-                                .content
-                                .iter()
-                                .filter_map(|c| match c {
-                                    anthropic::ResponseContent::Text { text } => {
-                                        Some(text.as_str())
-                                    }
-                                    _ => None,
-                                })
-                                .collect::<Vec<_>>()
-                                .join("");
+                        let response = client.generate(model, 1024, prompt).await?;
+                        if let Some(text) = response {
                             results[result_idx] = (*idx, Some(parse_response(&text)));
                         } else {
                             all_done = false;

crates/edit_prediction_cli/src/repair.rs 🔗

@@ -4,20 +4,18 @@
 //! predictions that need improvement (based on reverts_edits or low confidence),
 //! and uses an LLM to generate improved predictions.
 
+use crate::BatchProvider;
 use crate::PredictionProvider;
 use crate::anthropic_client::AnthropicClient;
 use crate::example::{Example, ExamplePrediction};
 use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
+use crate::openai_client::OpenAiClient;
 use crate::paths::LLM_CACHE_DB;
 use crate::word_diff::unified_to_word_diff;
-use anthropic::{Message, RequestContent, Role};
 use anyhow::Result;
 use std::io::{BufWriter, Write};
 use std::path::PathBuf;
 
-/// Model to use for repair.
-const MODEL: &str = "claude-sonnet-4-5";
-
 const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");
 
 /// Arguments for the repair command.
@@ -34,6 +32,17 @@ pub struct RepairArgs {
     /// Confidence threshold: repair predictions with confidence <= this value (1-5)
     #[clap(long, default_value = "2")]
     pub confidence_threshold: u8,
+
+    /// Which LLM provider to use (anthropic or openai)
+    #[clap(long, default_value = "anthropic")]
+    pub backend: BatchProvider,
+}
+
+fn model_for_backend(backend: BatchProvider) -> &'static str {
+    match backend {
+        BatchProvider::Anthropic => "claude-sonnet-4-5",
+        BatchProvider::Openai => "gpt-5.2",
+    }
 }
 
 /// Build the repair prompt for an example that needs improvement.
@@ -127,21 +136,100 @@ fn parse_repair_response(example: &Example, response_text: &str) -> Result<Examp
     })
 }
 
+enum RepairClient {
+    Anthropic(AnthropicClient),
+    OpenAi(OpenAiClient),
+}
+
+impl RepairClient {
+    async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
+        match self {
+            RepairClient::Anthropic(client) => {
+                let messages = vec![anthropic::Message {
+                    role: anthropic::Role::User,
+                    content: vec![anthropic::RequestContent::Text {
+                        text: prompt.to_string(),
+                        cache_control: None,
+                    }],
+                }];
+                let response = client.generate(model, max_tokens, messages).await?;
+                Ok(response.map(|r| {
+                    r.content
+                        .iter()
+                        .filter_map(|c| match c {
+                            anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+            RepairClient::OpenAi(client) => {
+                let messages = vec![open_ai::RequestMessage::User {
+                    content: open_ai::MessageContent::Plain(prompt.to_string()),
+                }];
+                let response = client.generate(model, max_tokens, messages).await?;
+                Ok(response.map(|r| {
+                    r.choices
+                        .into_iter()
+                        .filter_map(|choice| match choice.message {
+                            open_ai::RequestMessage::Assistant { content, .. } => {
+                                content.map(|c| match c {
+                                    open_ai::MessageContent::Plain(text) => text,
+                                    open_ai::MessageContent::Multipart(parts) => parts
+                                        .into_iter()
+                                        .filter_map(|p| match p {
+                                            open_ai::MessagePart::Text { text } => Some(text),
+                                            _ => None,
+                                        })
+                                        .collect::<Vec<_>>()
+                                        .join(""),
+                                })
+                            }
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join("")
+                }))
+            }
+        }
+    }
+
+    async fn sync_batches(&self) -> Result<()> {
+        match self {
+            RepairClient::Anthropic(client) => client.sync_batches().await,
+            RepairClient::OpenAi(client) => client.sync_batches().await,
+        }
+    }
+}
+
 /// Run the repair process on a set of examples.
 pub async fn run_repair(
     examples: &mut [Example],
     args: &RepairArgs,
     output_path: Option<&PathBuf>,
 ) -> Result<()> {
-    let client = if args.no_batch {
-        AnthropicClient::plain()?
-    } else {
-        AnthropicClient::batch(&LLM_CACHE_DB)?
+    let model = model_for_backend(args.backend);
+    let client = match args.backend {
+        BatchProvider::Anthropic => {
+            if args.no_batch {
+                RepairClient::Anthropic(AnthropicClient::plain()?)
+            } else {
+                RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
+            }
+        }
+        BatchProvider::Openai => {
+            if args.no_batch {
+                RepairClient::OpenAi(OpenAiClient::plain()?)
+            } else {
+                RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
+            }
+        }
     };
 
     eprintln!(
-        "Using model: {}, batching: {}, confidence_threshold: {}",
-        MODEL, !args.no_batch, args.confidence_threshold
+        "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}",
+        model, args.backend, !args.no_batch, args.confidence_threshold
     );
 
     // First pass: identify examples that need repair and build prompts
@@ -185,51 +273,15 @@ pub async fn run_repair(
         for (i, (idx, prompt)) in repair_items.iter().enumerate() {
             eprint!("\rProcessing {}/{}", i + 1, repair_items.len());
 
-            let messages = vec![Message {
-                role: Role::User,
-                content: vec![RequestContent::Text {
-                    text: prompt.clone(),
-                    cache_control: None,
-                }],
-            }];
-
-            let response = client.generate(MODEL, 16384, messages).await?;
-            let result = response.map(|r| {
-                r.content
-                    .iter()
-                    .filter_map(|c| match c {
-                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                        _ => None,
-                    })
-                    .collect::<Vec<_>>()
-                    .join("")
-            });
-            results.push((*idx, result));
+            let response = client.generate(model, 16384, prompt).await?;
+            results.push((*idx, response));
         }
         eprintln!();
     } else {
         // Queue all for batching
         for (idx, prompt) in &repair_items {
-            let messages = vec![Message {
-                role: Role::User,
-                content: vec![RequestContent::Text {
-                    text: prompt.clone(),
-                    cache_control: None,
-                }],
-            }];
-
-            let response = client.generate(MODEL, 16384, messages).await?;
-            let result = response.map(|r| {
-                r.content
-                    .iter()
-                    .filter_map(|c| match c {
-                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
-                        _ => None,
-                    })
-                    .collect::<Vec<_>>()
-                    .join("")
-            });
-            results.push((*idx, result));
+            let response = client.generate(model, 16384, prompt).await?;
+            results.push((*idx, response));
         }
 
         // Sync batches (upload pending, download finished)
@@ -245,27 +297,8 @@ pub async fn run_repair(
                 let mut all_done = true;
                 for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
                     if results[result_idx].1.is_none() {
-                        let messages = vec![Message {
-                            role: Role::User,
-                            content: vec![RequestContent::Text {
-                                text: prompt.clone(),
-                                cache_control: None,
-                            }],
-                        }];
-
-                        let response = client.generate(MODEL, 16384, messages).await?;
-                        if let Some(r) = response {
-                            let text = r
-                                .content
-                                .iter()
-                                .filter_map(|c| match c {
-                                    anthropic::ResponseContent::Text { text } => {
-                                        Some(text.as_str())
-                                    }
-                                    _ => None,
-                                })
-                                .collect::<Vec<_>>()
-                                .join("");
+                        let response = client.generate(model, 16384, prompt).await?;
+                        if let Some(text) = response {
                             results[result_idx] = (*idx, Some(text));
                         } else {
                             all_done = false;

crates/open_ai/Cargo.toml 🔗

@@ -19,6 +19,7 @@ schemars = ["dep:schemars"]
 anyhow.workspace = true
 futures.workspace = true
 http_client.workspace = true
+rand.workspace = true
 schemars = { workspace = true, optional = true }
 log.workspace = true
 serde.workspace = true

crates/open_ai/src/batches.rs 🔗

@@ -0,0 +1,332 @@
+use anyhow::Result;
+use futures::AsyncReadExt;
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+
+use crate::{Request, RequestError, Response};
+
+/// A single request within a batch
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchRequestItem {
+    pub custom_id: String,
+    pub method: String,
+    pub url: String,
+    pub body: Request,
+}
+
+impl BatchRequestItem {
+    pub fn new(custom_id: String, request: Request) -> Self {
+        Self {
+            custom_id,
+            method: "POST".to_string(),
+            url: "/v1/chat/completions".to_string(),
+            body: request,
+        }
+    }
+
+    pub fn to_jsonl_line(&self) -> Result<String, serde_json::Error> {
+        serde_json::to_string(self)
+    }
+}
+
+/// Request to create a batch
+#[derive(Debug, Serialize)]
+pub struct CreateBatchRequest {
+    pub input_file_id: String,
+    pub endpoint: String,
+    pub completion_window: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<serde_json::Value>,
+}
+
+impl CreateBatchRequest {
+    pub fn new(input_file_id: String) -> Self {
+        Self {
+            input_file_id,
+            endpoint: "/v1/chat/completions".to_string(),
+            completion_window: "24h".to_string(),
+            metadata: None,
+        }
+    }
+}
+
+/// Response from batch creation or retrieval
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Batch {
+    pub id: String,
+    pub object: String,
+    pub endpoint: String,
+    pub input_file_id: String,
+    pub completion_window: String,
+    pub status: String,
+    pub output_file_id: Option<String>,
+    pub error_file_id: Option<String>,
+    pub created_at: u64,
+    #[serde(default)]
+    pub in_progress_at: Option<u64>,
+    #[serde(default)]
+    pub expires_at: Option<u64>,
+    #[serde(default)]
+    pub finalizing_at: Option<u64>,
+    #[serde(default)]
+    pub completed_at: Option<u64>,
+    #[serde(default)]
+    pub failed_at: Option<u64>,
+    #[serde(default)]
+    pub expired_at: Option<u64>,
+    #[serde(default)]
+    pub cancelling_at: Option<u64>,
+    #[serde(default)]
+    pub cancelled_at: Option<u64>,
+    #[serde(default)]
+    pub request_counts: Option<BatchRequestCounts>,
+    #[serde(default)]
+    pub metadata: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Default)]
+pub struct BatchRequestCounts {
+    pub total: u64,
+    pub completed: u64,
+    pub failed: u64,
+}
+
+/// Response from file upload
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FileObject {
+    pub id: String,
+    pub object: String,
+    pub bytes: u64,
+    pub created_at: u64,
+    pub filename: String,
+    pub purpose: String,
+}
+
+/// Individual result from batch output
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchOutputItem {
+    pub id: String,
+    pub custom_id: String,
+    pub response: Option<BatchResponseBody>,
+    pub error: Option<BatchError>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchResponseBody {
+    pub status_code: u16,
+    pub body: Response,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchError {
+    pub code: String,
+    pub message: String,
+}
+
+/// Upload a JSONL file for batch processing
+pub async fn upload_batch_file(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    filename: &str,
+    content: Vec<u8>,
+) -> Result<FileObject, RequestError> {
+    let uri = format!("{api_url}/files");
+
+    let boundary = format!("----WebKitFormBoundary{:x}", rand::random::<u64>());
+
+    let mut body = Vec::new();
+    body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
+    body.extend_from_slice(b"Content-Disposition: form-data; name=\"purpose\"\r\n\r\n");
+    body.extend_from_slice(b"batch\r\n");
+    body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
+    body.extend_from_slice(
+        format!("Content-Disposition: form-data; name=\"file\"; filename=\"{filename}\"\r\n")
+            .as_bytes(),
+    );
+    body.extend_from_slice(b"Content-Type: application/jsonl\r\n\r\n");
+    body.extend_from_slice(&content);
+    body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes());
+
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Authorization", format!("Bearer {}", api_key.trim()))
+        .header(
+            "Content-Type",
+            format!("multipart/form-data; boundary={boundary}"),
+        )
+        .body(AsyncBody::from(body))
+        .map_err(|e| RequestError::Other(e.into()))?;
+
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+    } else {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Err(RequestError::HttpResponseError {
+            provider: "openai".to_owned(),
+            status_code: response.status(),
+            body,
+            headers: response.headers().clone(),
+        })
+    }
+}
+
+/// Create a batch from an uploaded file
+pub async fn create_batch(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: CreateBatchRequest,
+) -> Result<Batch, RequestError> {
+    let uri = format!("{api_url}/batches");
+
+    let serialized = serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?;
+
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Authorization", format!("Bearer {}", api_key.trim()))
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(serialized))
+        .map_err(|e| RequestError::Other(e.into()))?;
+
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+    } else {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Err(RequestError::HttpResponseError {
+            provider: "openai".to_owned(),
+            status_code: response.status(),
+            body,
+            headers: response.headers().clone(),
+        })
+    }
+}
+
+/// Retrieve batch status
+pub async fn retrieve_batch(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    batch_id: &str,
+) -> Result<Batch, RequestError> {
+    let uri = format!("{api_url}/batches/{batch_id}");
+
+    let request = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(uri)
+        .header("Authorization", format!("Bearer {}", api_key.trim()))
+        .body(AsyncBody::default())
+        .map_err(|e| RequestError::Other(e.into()))?;
+
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+    } else {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Err(RequestError::HttpResponseError {
+            provider: "openai".to_owned(),
+            status_code: response.status(),
+            body,
+            headers: response.headers().clone(),
+        })
+    }
+}
+
+/// Download file content (for batch results)
+pub async fn download_file(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    file_id: &str,
+) -> Result<String, RequestError> {
+    let uri = format!("{api_url}/files/{file_id}/content");
+
+    let request = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(uri)
+        .header("Authorization", format!("Bearer {}", api_key.trim()))
+        .body(AsyncBody::default())
+        .map_err(|e| RequestError::Other(e.into()))?;
+
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Ok(body)
+    } else {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Err(RequestError::HttpResponseError {
+            provider: "openai".to_owned(),
+            status_code: response.status(),
+            body,
+            headers: response.headers().clone(),
+        })
+    }
+}
+
+/// Parse batch output JSONL into individual results
+pub fn parse_batch_output(content: &str) -> Result<Vec<BatchOutputItem>, serde_json::Error> {
+    content
+        .lines()
+        .filter(|line| !line.trim().is_empty())
+        .map(|line| serde_json::from_str(line))
+        .collect()
+}

crates/open_ai/src/open_ai.rs 🔗

@@ -1,3 +1,4 @@
+pub mod batches;
 pub mod responses;
 
 use anyhow::{Context as _, Result, anyhow};
@@ -529,6 +530,52 @@ pub struct ResponseStreamEvent {
     pub usage: Option<Usage>,
 }
 
+pub async fn non_streaming_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+) -> Result<Response, RequestError> {
+    let uri = format!("{api_url}/chat/completions");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key.trim()));
+
+    let request = request_builder
+        .body(AsyncBody::from(
+            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
+        ))
+        .map_err(|e| RequestError::Other(e.into()))?;
+
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+    } else {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Err(RequestError::HttpResponseError {
+            provider: "openai".to_owned(),
+            status_code: response.status(),
+            body,
+            headers: response.headers().clone(),
+        })
+    }
+}
+
 pub async fn stream_completion(
     client: &dyn HttpClient,
     provider_name: &str,