diff --git a/Cargo.lock b/Cargo.lock
index 2faee6057d273e1b83d39c4e245f7f31a98c9fde..4ed076538df56f295adafca2af9e2441286aedcc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5338,6 +5338,7 @@ dependencies = [
  "libc",
  "log",
  "node_runtime",
+ "open_ai",
  "paths",
  "pretty_assertions",
  "project",
@@ -11140,6 +11141,7 @@ dependencies = [
  "futures 0.3.31",
  "http_client",
  "log",
+ "rand 0.9.2",
  "schemars",
  "serde",
  "serde_json",
diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml
index fde5876139c1d407d485e76d748c1a4763b5878a..1cb79add3019441c046dc4f9cb3ab0e6cd0b914d 100644
--- a/crates/edit_prediction_cli/Cargo.toml
+++ b/crates/edit_prediction_cli/Cargo.toml
@@ -37,6 +37,7 @@
 languages = { workspace = true, features = ["load-grammars"] }
 libc.workspace = true
 log.workspace = true
 node_runtime.workspace = true
+open_ai.workspace = true
 paths.workspace = true
 project.workspace = true
diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs
index 21b255d5f99dc00e5264ffe901cced1352515fa1..bed15c347dc619e772350469a19500ebc18a6da2 100644
--- a/crates/edit_prediction_cli/src/distill.rs
+++ b/crates/edit_prediction_cli/src/distill.rs
@@ -1,17 +1,27 @@
 use anyhow::Result;
-use std::mem;
-use crate::example::Example;
+use crate::{PredictionProvider, example::Example};
 
 pub async fn run_distill(example: &mut Example) -> Result<()> {
-    let predictions = mem::take(&mut example.predictions)
+    let has_repair = example
+        .predictions
+        .iter()
+        .find(|p| p.provider == PredictionProvider::Repair);
+    let predictions = if let Some(has_repair) = has_repair {
+        vec![has_repair]
+    } else {
+        example.predictions.iter().collect()
+    };
+
+    let expected_patches = predictions
         .into_iter()
-        .filter_map(|p| p.actual_patch)
+        .filter_map(|p| p.actual_patch.clone())
         .collect();
-    example.spec.expected_patches = predictions;
+    example.spec.expected_patches = expected_patches;
 
     example.prompt = None;
     example.predictions = Vec::new();
     example.score = Vec::new();
+    example.qa = Vec::new();
 
     Ok(())
 }
diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs
index 381e914969db312df5dafcf1df1ab6e6d7ba0cc8..ea0d5a897ac3a83a18d8a75f1b50bcab74a90cdd 100644
--- a/crates/edit_prediction_cli/src/example.rs
+++ b/crates/edit_prediction_cli/src/example.rs
@@ -73,6 +73,7 @@ pub struct ExamplePromptInputs {
 pub struct ExamplePrompt {
     pub input: String,
     pub expected_output: String,
+    pub rejected_output: Option<String>, // For DPO
     pub provider: PredictionProvider,
 }
diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs
index 75559dd2e7689b8937a5ab2e4d71386167e94c7d..9813615540f1b6d58dfabf558fd18526e8d38d1d 100644
--- a/crates/edit_prediction_cli/src/format_prompt.rs
+++ b/crates/edit_prediction_cli/src/format_prompt.rs
@@ -48,27 +48,31 @@ pub async fn run_format_prompt(
     let snapshot = cx.background_spawn(snapshot_fut).await;
 
     match args.provider {
-        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
+        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
            step_progress.set_substatus("formatting teacher prompt");
 
            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
                cursor_point,
                &snapshot,
-                edit_prediction::zeta2::max_editable_tokens(version),
+                edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()),
                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
            );
            let editable_range =
editable_range.to_offset(&snapshot); let context_range = context_range.to_offset(&snapshot); let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range); + let expected_output = example + .spec + .expected_patches + .first() + .cloned() + .unwrap_or_default(); + let rejected_output = example.spec.rejected_patch.clone(); + example.prompt = Some(ExamplePrompt { input: prompt, - expected_output: example - .spec - .expected_patches - .first() - .cloned() - .unwrap_or_default(), + expected_output, + rejected_output, provider: args.provider, }); } @@ -107,9 +111,16 @@ pub async fn run_format_prompt( .clone(), version, )?; + let rejected_output = example + .spec + .rejected_patch + .as_ref() + .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok()); + example.prompt = Some(ExamplePrompt { input: prompt, expected_output, + rejected_output, provider: args.provider, }); } diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 821dc6d86be489ddd2eb086e074f7ce92af8ef6f..38cfc66f7e5a7215be72c81473bdbac5241e402a 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -7,6 +7,7 @@ mod git; mod headless; mod load_project; mod metrics; +mod openai_client; mod parse_output; mod paths; mod predict; @@ -241,14 +242,51 @@ struct EvalArgs { summary_json: Option, } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum TeacherBackend { + Sonnet45, + Gpt52, +} + +impl std::fmt::Display for TeacherBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TeacherBackend::Sonnet45 => write!(f, "sonnet45"), + TeacherBackend::Gpt52 => write!(f, "gpt52"), + } + } +} + +impl std::str::FromStr for TeacherBackend { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45), + "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52), + "v0114180editableregion" => Ok(TeacherBackend::Sonnet45), + _ => anyhow::bail!("unknown teacher backend `{s}`. 
Valid options: sonnet45, gpt52"), + } + } +} + +impl TeacherBackend { + pub fn model_name(&self) -> &'static str { + match self { + TeacherBackend::Sonnet45 => "claude-sonnet-4-5", + TeacherBackend::Gpt52 => "gpt-5.2", + } + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] enum PredictionProvider { Sweep, Mercury, Zeta1, Zeta2(ZetaVersion), - Teacher(ZetaVersion), - TeacherNonBatching(ZetaVersion), + Teacher(TeacherBackend), + TeacherNonBatching(TeacherBackend), Repair, } @@ -265,9 +303,9 @@ impl std::fmt::Display for PredictionProvider { PredictionProvider::Mercury => write!(f, "mercury"), PredictionProvider::Zeta1 => write!(f, "zeta1"), PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"), - PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"), - PredictionProvider::TeacherNonBatching(version) => { - write!(f, "teacher-non-batching:{version}") + PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"), + PredictionProvider::TeacherNonBatching(backend) => { + write!(f, "teacher-non-batching:{backend}") } PredictionProvider::Repair => write!(f, "repair"), } @@ -277,28 +315,38 @@ impl std::fmt::Display for PredictionProvider { impl std::str::FromStr for PredictionProvider { type Err = anyhow::Error; - fn from_str(mut s: &str) -> Result { - let mut version = ZetaVersion::default(); - if let Some((first, second)) = s.split_once(':') { - version = ZetaVersion::parse(second)?; - s = first; - } + fn from_str(s: &str) -> Result { + let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a))); - let s_lower = s.to_lowercase(); - match s_lower.as_str() { + let provider_lower = provider.to_lowercase(); + match provider_lower.as_str() { "sweep" => Ok(PredictionProvider::Sweep), "mercury" => Ok(PredictionProvider::Mercury), "zeta1" => Ok(PredictionProvider::Zeta1), - "zeta2" => Ok(PredictionProvider::Zeta2(version)), - "teacher" => Ok(PredictionProvider::Teacher(version)), + "zeta2" => { + let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default(); + Ok(PredictionProvider::Zeta2(version)) + } + "teacher" => { + let backend = arg + .map(|a| a.parse()) + .transpose()? + .unwrap_or(TeacherBackend::Sonnet45); + Ok(PredictionProvider::Teacher(backend)) + } "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => { - Ok(PredictionProvider::TeacherNonBatching(version)) + let backend = arg + .map(|a| a.parse()) + .transpose()? + .unwrap_or(TeacherBackend::Sonnet45); + Ok(PredictionProvider::TeacherNonBatching(backend)) } "repair" => Ok(PredictionProvider::Repair), _ => { anyhow::bail!( - "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:, teacher, teacher-non-batching, repair\n\ + "unknown provider `{provider}`. 
Valid options: sweep, mercury, zeta1, zeta2, zeta2:, teacher, teacher:, teacher-non-batching, repair\n\ For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\ + For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\ Available zeta versions:\n{}", ZetaVersion::options_as_string() ) @@ -347,9 +395,18 @@ struct SynthesizeArgs { #[derive(Debug, Args, Clone)] struct ImportBatchArgs { - /// Anthropic batch IDs to import (e.g., msgbatch_xxx) + /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI) #[clap(long, required = true, num_args = 1..)] batch_ids: Vec, + /// Which provider's batches to import (anthropic or openai) + #[clap(long, default_value = "anthropic")] + provider: BatchProvider, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] +enum BatchProvider { + Anthropic, + Openai, } impl EpArgs { @@ -537,11 +594,23 @@ fn main() { match &command { Command::ImportBatch(import_args) => { smol::block_on(async { - let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB) - .expect("Failed to create Anthropic client"); - if let Err(e) = client.import_batches(&import_args.batch_ids).await { - eprintln!("Error importing batches: {:?}", e); - std::process::exit(1); + match import_args.provider { + BatchProvider::Anthropic => { + let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB) + .expect("Failed to create Anthropic client"); + if let Err(e) = client.import_batches(&import_args.batch_ids).await { + eprintln!("Error importing Anthropic batches: {:?}", e); + std::process::exit(1); + } + } + BatchProvider::Openai => { + let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB) + .expect("Failed to create OpenAI client"); + if let Err(e) = client.import_batches(&import_args.batch_ids).await { + eprintln!("Error importing OpenAI batches: {:?}", e); + std::process::exit(1); + } + } } println!( "Successfully imported {} batch(es)", diff --git a/crates/edit_prediction_cli/src/openai_client.rs b/crates/edit_prediction_cli/src/openai_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..ad402c3472b238d0d822c9564d449c63d581fa53 --- /dev/null +++ b/crates/edit_prediction_cli/src/openai_client.rs @@ -0,0 +1,664 @@ +use anyhow::Result; +use http_client::HttpClient; +use indoc::indoc; +use open_ai::{ + MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage, + Response as OpenAiResponse, batches, non_streaming_completion, +}; +use reqwest_client::ReqwestClient; +use sqlez::bindable::Bind; +use sqlez::bindable::StaticColumnCount; +use sqlez_macros::sql; +use std::hash::Hash; +use std::hash::Hasher; +use std::path::Path; +use std::sync::{Arc, Mutex}; + +pub struct PlainOpenAiClient { + pub http_client: Arc, + pub api_key: String, +} + +impl PlainOpenAiClient { + pub fn new() -> Result { + let http_client: Arc = Arc::new(ReqwestClient::new()); + let api_key = std::env::var("OPENAI_API_KEY") + .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?; + Ok(Self { + http_client, + api_key, + }) + } + + pub async fn generate( + &self, + model: &str, + max_tokens: u64, + messages: Vec, + ) -> Result { + let request = OpenAiRequest { + model: model.to_string(), + messages, + stream: false, + max_completion_tokens: Some(max_tokens), + stop: Vec::new(), + temperature: None, + tool_choice: None, + parallel_tool_calls: None, + tools: Vec::new(), + prompt_cache_key: None, + reasoning_effort: None, 
+ }; + + let response = non_streaming_completion( + self.http_client.as_ref(), + OPEN_AI_API_URL, + &self.api_key, + request, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + Ok(response) + } +} + +pub struct BatchingOpenAiClient { + connection: Mutex, + http_client: Arc, + api_key: String, +} + +struct CacheRow { + request_hash: String, + request: Option, + response: Option, + batch_id: Option, +} + +impl StaticColumnCount for CacheRow { + fn column_count() -> usize { + 4 + } +} + +impl Bind for CacheRow { + fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result { + let next_index = statement.bind(&self.request_hash, start_index)?; + let next_index = statement.bind(&self.request, next_index)?; + let next_index = statement.bind(&self.response, next_index)?; + let next_index = statement.bind(&self.batch_id, next_index)?; + Ok(next_index) + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct SerializableRequest { + model: String, + max_tokens: u64, + messages: Vec, +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct SerializableMessage { + role: String, + content: String, +} + +impl BatchingOpenAiClient { + fn new(cache_path: &Path) -> Result { + let http_client: Arc = Arc::new(ReqwestClient::new()); + let api_key = std::env::var("OPENAI_API_KEY") + .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?; + + let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap()); + let mut statement = sqlez::statement::Statement::prepare( + &connection, + indoc! {" + CREATE TABLE IF NOT EXISTS openai_cache ( + request_hash TEXT PRIMARY KEY, + request TEXT, + response TEXT, + batch_id TEXT + ); + "}, + )?; + statement.exec()?; + drop(statement); + + Ok(Self { + connection: Mutex::new(connection), + http_client, + api_key, + }) + } + + pub fn lookup( + &self, + model: &str, + max_tokens: u64, + messages: &[RequestMessage], + ) -> Result> { + let request_hash_str = Self::request_hash(model, max_tokens, messages); + let connection = self.connection.lock().unwrap(); + let response: Vec = connection.select_bound( + &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;), + )?(request_hash_str.as_str())?; + Ok(response + .into_iter() + .next() + .and_then(|text| serde_json::from_str(&text).ok())) + } + + pub fn mark_for_batch( + &self, + model: &str, + max_tokens: u64, + messages: &[RequestMessage], + ) -> Result<()> { + let request_hash = Self::request_hash(model, max_tokens, messages); + + let serializable_messages: Vec = messages + .iter() + .map(|msg| SerializableMessage { + role: message_role_to_string(msg), + content: message_content_to_string(msg), + }) + .collect(); + + let serializable_request = SerializableRequest { + model: model.to_string(), + max_tokens, + messages: serializable_messages, + }; + + let request = Some(serde_json::to_string(&serializable_request)?); + let cache_row = CacheRow { + request_hash, + request, + response: None, + batch_id: None, + }; + let connection = self.connection.lock().unwrap(); + connection.exec_bound::(sql!( + INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?( + cache_row, + ) + } + + async fn generate( + &self, + model: &str, + max_tokens: u64, + messages: Vec, + ) -> Result> { + let response = self.lookup(model, max_tokens, &messages)?; + if let Some(response) = response { + return Ok(Some(response)); + } + + self.mark_for_batch(model, max_tokens, &messages)?; + + 
Ok(None) + } + + async fn sync_batches(&self) -> Result<()> { + let _batch_ids = self.upload_pending_requests().await?; + self.download_finished_batches().await + } + + pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> { + for batch_id in batch_ids { + log::info!("Importing OpenAI batch {}", batch_id); + + let batch_status = batches::retrieve_batch( + self.http_client.as_ref(), + OPEN_AI_API_URL, + &self.api_key, + batch_id, + ) + .await + .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?; + + log::info!("Batch {} status: {}", batch_id, batch_status.status); + + if batch_status.status != "completed" { + log::warn!( + "Batch {} is not completed (status: {}), skipping", + batch_id, + batch_status.status + ); + continue; + } + + let output_file_id = batch_status.output_file_id.ok_or_else(|| { + anyhow::anyhow!("Batch {} completed but has no output file", batch_id) + })?; + + let results_content = batches::download_file( + self.http_client.as_ref(), + OPEN_AI_API_URL, + &self.api_key, + &output_file_id, + ) + .await + .map_err(|e| { + anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e) + })?; + + let results = batches::parse_batch_output(&results_content) + .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?; + + let mut updates: Vec<(String, String, String)> = Vec::new(); + let mut success_count = 0; + let mut error_count = 0; + + for result in results { + let request_hash = result + .custom_id + .strip_prefix("req_hash_") + .unwrap_or(&result.custom_id) + .to_string(); + + if let Some(response_body) = result.response { + if response_body.status_code == 200 { + let response_json = serde_json::to_string(&response_body.body)?; + updates.push((request_hash, response_json, batch_id.clone())); + success_count += 1; + } else { + log::error!( + "Batch request {} failed with status {}", + request_hash, + response_body.status_code + ); + let error_json = serde_json::json!({ + "error": { + "type": "http_error", + "status_code": response_body.status_code + } + }) + .to_string(); + updates.push((request_hash, error_json, batch_id.clone())); + error_count += 1; + } + } else if let Some(error) = result.error { + log::error!( + "Batch request {} failed: {}: {}", + request_hash, + error.code, + error.message + ); + let error_json = serde_json::json!({ + "error": { + "type": error.code, + "message": error.message + } + }) + .to_string(); + updates.push((request_hash, error_json, batch_id.clone())); + error_count += 1; + } + } + + let connection = self.connection.lock().unwrap(); + connection.with_savepoint("batch_import", || { + let q = sql!( + INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id) + VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?) + ); + let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?; + for (request_hash, response_json, batch_id) in &updates { + exec(( + request_hash.as_str(), + request_hash.as_str(), + response_json.as_str(), + batch_id.as_str(), + ))?; + } + Ok(()) + })?; + + log::info!( + "Imported batch {}: {} successful, {} errors", + batch_id, + success_count, + error_count + ); + } + + Ok(()) + } + + async fn download_finished_batches(&self) -> Result<()> { + let batch_ids: Vec = { + let connection = self.connection.lock().unwrap(); + let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL); + connection.select(q)?()? 
+ }; + + for batch_id in &batch_ids { + let batch_status = batches::retrieve_batch( + self.http_client.as_ref(), + OPEN_AI_API_URL, + &self.api_key, + batch_id, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + log::info!("Batch {} status: {}", batch_id, batch_status.status); + + if batch_status.status == "completed" { + let output_file_id = match batch_status.output_file_id { + Some(id) => id, + None => { + log::warn!("Batch {} completed but has no output file", batch_id); + continue; + } + }; + + let results_content = batches::download_file( + self.http_client.as_ref(), + OPEN_AI_API_URL, + &self.api_key, + &output_file_id, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + let results = batches::parse_batch_output(&results_content) + .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?; + + let mut updates: Vec<(String, String)> = Vec::new(); + let mut success_count = 0; + + for result in results { + let request_hash = result + .custom_id + .strip_prefix("req_hash_") + .unwrap_or(&result.custom_id) + .to_string(); + + if let Some(response_body) = result.response { + if response_body.status_code == 200 { + let response_json = serde_json::to_string(&response_body.body)?; + updates.push((response_json, request_hash)); + success_count += 1; + } else { + log::error!( + "Batch request {} failed with status {}", + request_hash, + response_body.status_code + ); + let error_json = serde_json::json!({ + "error": { + "type": "http_error", + "status_code": response_body.status_code + } + }) + .to_string(); + updates.push((error_json, request_hash)); + } + } else if let Some(error) = result.error { + log::error!( + "Batch request {} failed: {}: {}", + request_hash, + error.code, + error.message + ); + let error_json = serde_json::json!({ + "error": { + "type": error.code, + "message": error.message + } + }) + .to_string(); + updates.push((error_json, request_hash)); + } + } + + let connection = self.connection.lock().unwrap(); + connection.with_savepoint("batch_download", || { + let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?); + let mut exec = connection.exec_bound::<(&str, &str)>(q)?; + for (response_json, request_hash) in &updates { + exec((response_json.as_str(), request_hash.as_str()))?; + } + Ok(()) + })?; + log::info!("Downloaded {} successful requests", success_count); + } + } + + Ok(()) + } + + async fn upload_pending_requests(&self) -> Result> { + const BATCH_CHUNK_SIZE: i32 = 16_000; + let mut all_batch_ids = Vec::new(); + let mut total_uploaded = 0; + + loop { + let rows: Vec<(String, String)> = { + let connection = self.connection.lock().unwrap(); + let q = sql!( + SELECT request_hash, request FROM openai_cache + WHERE batch_id IS NULL AND response IS NULL + LIMIT ? + ); + connection.select_bound(q)?(BATCH_CHUNK_SIZE)? 
+ }; + + if rows.is_empty() { + break; + } + + let request_hashes: Vec = rows.iter().map(|(hash, _)| hash.clone()).collect(); + + let mut jsonl_content = String::new(); + for (hash, request_str) in &rows { + let serializable_request: SerializableRequest = + serde_json::from_str(request_str).unwrap(); + + let messages: Vec = serializable_request + .messages + .into_iter() + .map(|msg| match msg.role.as_str() { + "user" => RequestMessage::User { + content: MessageContent::Plain(msg.content), + }, + "assistant" => RequestMessage::Assistant { + content: Some(MessageContent::Plain(msg.content)), + tool_calls: Vec::new(), + }, + "system" => RequestMessage::System { + content: MessageContent::Plain(msg.content), + }, + _ => RequestMessage::User { + content: MessageContent::Plain(msg.content), + }, + }) + .collect(); + + let request = OpenAiRequest { + model: serializable_request.model, + messages, + stream: false, + max_completion_tokens: Some(serializable_request.max_tokens), + stop: Vec::new(), + temperature: None, + tool_choice: None, + parallel_tool_calls: None, + tools: Vec::new(), + prompt_cache_key: None, + reasoning_effort: None, + }; + + let custom_id = format!("req_hash_{}", hash); + let batch_item = batches::BatchRequestItem::new(custom_id, request); + let line = batch_item + .to_jsonl_line() + .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?; + jsonl_content.push_str(&line); + jsonl_content.push('\n'); + } + + let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp()); + let file_obj = batches::upload_batch_file( + self.http_client.as_ref(), + OPEN_AI_API_URL, + &self.api_key, + &filename, + jsonl_content.into_bytes(), + ) + .await + .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?; + + let batch = batches::create_batch( + self.http_client.as_ref(), + OPEN_AI_API_URL, + &self.api_key, + batches::CreateBatchRequest::new(file_obj.id), + ) + .await + .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?; + + { + let connection = self.connection.lock().unwrap(); + connection.with_savepoint("batch_upload", || { + let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?); + let mut exec = connection.exec_bound::<(&str, &str)>(q)?; + for hash in &request_hashes { + exec((batch.id.as_str(), hash.as_str()))?; + } + Ok(()) + })?; + } + + let batch_len = rows.len(); + total_uploaded += batch_len; + log::info!( + "Uploaded batch {} with {} requests ({} total)", + batch.id, + batch_len, + total_uploaded + ); + + all_batch_ids.push(batch.id); + } + + if !all_batch_ids.is_empty() { + log::info!( + "Finished uploading {} batches with {} total requests", + all_batch_ids.len(), + total_uploaded + ); + } + + Ok(all_batch_ids) + } + + fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String { + let mut hasher = std::hash::DefaultHasher::new(); + "openai".hash(&mut hasher); + model.hash(&mut hasher); + max_tokens.hash(&mut hasher); + for msg in messages { + message_content_to_string(msg).hash(&mut hasher); + } + let request_hash = hasher.finish(); + format!("{request_hash:016x}") + } +} + +fn message_role_to_string(msg: &RequestMessage) -> String { + match msg { + RequestMessage::User { .. } => "user".to_string(), + RequestMessage::Assistant { .. } => "assistant".to_string(), + RequestMessage::System { .. } => "system".to_string(), + RequestMessage::Tool { .. 
} => "tool".to_string(), + } +} + +fn message_content_to_string(msg: &RequestMessage) -> String { + match msg { + RequestMessage::User { content } => content_to_string(content), + RequestMessage::Assistant { content, .. } => { + content.as_ref().map(content_to_string).unwrap_or_default() + } + RequestMessage::System { content } => content_to_string(content), + RequestMessage::Tool { content, .. } => content_to_string(content), + } +} + +fn content_to_string(content: &MessageContent) -> String { + match content { + MessageContent::Plain(text) => text.clone(), + MessageContent::Multipart(parts) => parts + .iter() + .filter_map(|part| match part { + open_ai::MessagePart::Text { text } => Some(text.clone()), + _ => None, + }) + .collect::>() + .join("\n"), + } +} + +pub enum OpenAiClient { + Plain(PlainOpenAiClient), + Batch(BatchingOpenAiClient), + #[allow(dead_code)] + Dummy, +} + +impl OpenAiClient { + pub fn plain() -> Result { + Ok(Self::Plain(PlainOpenAiClient::new()?)) + } + + pub fn batch(cache_path: &Path) -> Result { + Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?)) + } + + #[allow(dead_code)] + pub fn dummy() -> Self { + Self::Dummy + } + + pub async fn generate( + &self, + model: &str, + max_tokens: u64, + messages: Vec, + ) -> Result> { + match self { + OpenAiClient::Plain(plain_client) => plain_client + .generate(model, max_tokens, messages) + .await + .map(Some), + OpenAiClient::Batch(batching_client) => { + batching_client.generate(model, max_tokens, messages).await + } + OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"), + } + } + + pub async fn sync_batches(&self) -> Result<()> { + match self { + OpenAiClient::Plain(_) => Ok(()), + OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await, + OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"), + } + } + + pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> { + match self { + OpenAiClient::Plain(_) => { + anyhow::bail!("Import batches is only supported with batching client") + } + OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await, + OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"), + } + } +} diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 79669df01078269ca28d4ed9a2a17cfc2f0edfb1..c0b6af3f71de6ee3134e449c7db50b129c3b221b 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -1,10 +1,11 @@ use crate::{ - FormatPromptArgs, PredictArgs, PredictionProvider, + FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend, anthropic_client::AnthropicClient, example::{Example, ExamplePrediction, ExamplePrompt}, format_prompt::{TeacherPrompt, run_format_prompt}, headless::EpAppState, load_project::run_load_project, + openai_client::OpenAiClient, paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR}, progress::{ExampleProgress, InfoStyle, Step}, retrieve_context::run_context_retrieval, @@ -20,9 +21,9 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, }, }; -use zeta_prompt::ZetaVersion; static ANTHROPIC_CLIENT: OnceLock = OnceLock::new(); +static OPENAI_CLIENT: OnceLock = OnceLock::new(); pub async fn run_prediction( example: &mut Example, @@ -53,7 +54,7 @@ pub async fn run_prediction( run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?; - if let PredictionProvider::Teacher(version) | 
PredictionProvider::TeacherNonBatching(version) = + if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) = provider { let _step_progress = example_progress.start(Step::Predict); @@ -68,7 +69,7 @@ pub async fn run_prediction( .await?; let batched = matches!(provider, PredictionProvider::Teacher(..)); - return predict_anthropic(example, repetition_count, version, batched).await; + return predict_teacher(example, backend, batched).await; } run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?; @@ -148,6 +149,7 @@ pub async fn run_prediction( updated_example.prompt.get_or_insert(ExamplePrompt { input: prompt, expected_output: String::new(), + rejected_output: None, provider, }); } @@ -255,13 +257,23 @@ pub async fn run_prediction( Ok(()) } +async fn predict_teacher( + example: &mut Example, + backend: TeacherBackend, + batched: bool, +) -> anyhow::Result<()> { + match backend { + TeacherBackend::Sonnet45 => predict_anthropic(example, backend, batched).await, + TeacherBackend::Gpt52 => predict_openai(example, backend, batched).await, + } +} + async fn predict_anthropic( example: &mut Example, - _repetition_count: usize, - version: ZetaVersion, + backend: TeacherBackend, batched: bool, ) -> anyhow::Result<()> { - let llm_model_name = "claude-sonnet-4-5"; + let llm_model_name = backend.model_name(); let max_tokens = 16384; let llm_client = ANTHROPIC_CLIENT.get_or_init(|| { let client = if batched { @@ -300,16 +312,83 @@ async fn predict_anthropic( .collect::>() .join("\n"); - let actual_patch = TeacherPrompt::parse(&example, &actual_output)?; + let actual_patch = TeacherPrompt::parse(example, &actual_output)?; let prediction = ExamplePrediction { actual_patch: Some(actual_patch), actual_output, error: None, provider: if batched { - PredictionProvider::Teacher(version) + PredictionProvider::Teacher(backend) } else { - PredictionProvider::TeacherNonBatching(version) + PredictionProvider::TeacherNonBatching(backend) + }, + }; + + example.predictions.push(prediction); + Ok(()) +} + +async fn predict_openai( + example: &mut Example, + backend: TeacherBackend, + batched: bool, +) -> anyhow::Result<()> { + let llm_model_name = backend.model_name(); + let max_tokens = 16384; + let llm_client = OPENAI_CLIENT.get_or_init(|| { + let client = if batched { + OpenAiClient::batch(&crate::paths::LLM_CACHE_DB) + } else { + OpenAiClient::plain() + }; + client.expect("Failed to create OpenAI client") + }); + + let prompt = example.prompt.as_ref().context("Prompt is required")?; + + let messages = vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt.input.clone()), + }]; + + let Some(response) = llm_client + .generate(llm_model_name, max_tokens, messages) + .await? + else { + // Request stashed for batched processing + return Ok(()); + }; + + let actual_output = response + .choices + .into_iter() + .filter_map(|choice| match choice.message { + open_ai::RequestMessage::Assistant { content, .. 
} => content.map(|c| match c { + open_ai::MessageContent::Plain(text) => text, + open_ai::MessageContent::Multipart(parts) => parts + .into_iter() + .filter_map(|p| match p { + open_ai::MessagePart::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join(""), + }), + _ => None, + }) + .collect::>() + .join("\n"); + + let actual_patch = TeacherPrompt::parse(example, &actual_output)?; + + let prediction = ExamplePrediction { + actual_patch: Some(actual_patch), + actual_output, + error: None, + provider: if batched { + PredictionProvider::Teacher(backend) + } else { + PredictionProvider::TeacherNonBatching(backend) }, }; @@ -319,16 +398,28 @@ async fn predict_anthropic( pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> { match provider { - Some(PredictionProvider::Teacher(..)) => { - let llm_client = ANTHROPIC_CLIENT.get_or_init(|| { - AnthropicClient::batch(&crate::paths::LLM_CACHE_DB) - .expect("Failed to create Anthropic client") - }); - llm_client - .sync_batches() - .await - .context("Failed to sync batches")?; - } + Some(PredictionProvider::Teacher(backend)) => match backend { + TeacherBackend::Sonnet45 => { + let llm_client = ANTHROPIC_CLIENT.get_or_init(|| { + AnthropicClient::batch(&crate::paths::LLM_CACHE_DB) + .expect("Failed to create Anthropic client") + }); + llm_client + .sync_batches() + .await + .context("Failed to sync Anthropic batches")?; + } + TeacherBackend::Gpt52 => { + let llm_client = OPENAI_CLIENT.get_or_init(|| { + OpenAiClient::batch(&crate::paths::LLM_CACHE_DB) + .expect("Failed to create OpenAI client") + }); + llm_client + .sync_batches() + .await + .context("Failed to sync OpenAI batches")?; + } + }, _ => (), }; Ok(()) diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs index f5005e08ae9db7b9c9b4d650b46af79f1223073a..59304bed825aabf37df48ead43d8d52525282946 100644 --- a/crates/edit_prediction_cli/src/qa.rs +++ b/crates/edit_prediction_cli/src/qa.rs @@ -1,22 +1,20 @@ //! Quality assessment of predictions using LLM-as-a-judge. //! -//! This module uses the Anthropic Batch API to evaluate prediction quality. -//! Caching is handled by the underlying AnthropicClient. +//! This module uses LLM Batch APIs to evaluate prediction quality. +//! Caching is handled by the underlying client. +use crate::BatchProvider; use crate::anthropic_client::AnthropicClient; use crate::example::Example; use crate::format_prompt::extract_cursor_excerpt_from_example; +use crate::openai_client::OpenAiClient; use crate::paths::LLM_CACHE_DB; use crate::word_diff::unified_to_word_diff; -use anthropic::{Message, RequestContent, Role}; use anyhow::Result; use serde::{Deserialize, Serialize}; use std::io::{BufWriter, Write}; use std::path::PathBuf; -/// Model to use for QA evaluation. -const MODEL: &str = "claude-sonnet-4-5"; - const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md"); /// Arguments for the QA command. @@ -29,6 +27,17 @@ pub struct QaArgs { /// Wait for batch to complete (polls every 30s) #[clap(long)] pub wait: bool, + + /// Which LLM provider to use (anthropic or openai) + #[clap(long, default_value = "anthropic")] + pub backend: BatchProvider, +} + +fn model_for_backend(backend: BatchProvider) -> &'static str { + match backend { + BatchProvider::Anthropic => "claude-sonnet-4-5", + BatchProvider::Openai => "gpt-5.2", + } } /// Result of QA evaluation for a single prediction. 
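Not part of the patch: a minimal sketch of how the two-phase batched OpenAI flow above is driven from inside edit_prediction_cli (the prompt text and batch id are placeholders). A `generate` call that returns `Ok(None)` means the request was only queued in the SQLite cache.

async fn batched_gpt52_roundtrip() -> anyhow::Result<()> {
    // Sketch only; mirrors what predict/qa/repair do with the batching client.
    let client = crate::openai_client::OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)?;
    let make_messages = || {
        vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain("<teacher prompt>".to_string()),
        }]
    };

    // First pass: a cache miss only records the request in the cache and yields
    // Ok(None); sync_batches() then uploads pending rows as an OpenAI batch and
    // downloads any batches that already finished.
    if client
        .generate("gpt-5.2", 16384, make_messages())
        .await?
        .is_none()
    {
        client.sync_batches().await?;
    }

    // Later (or via `import-batch --provider openai --batch-ids ...`): completed
    // results are written back to the cache, so the same call is now a cache hit.
    client.import_batches(&["batch_...".to_string()]).await?;
    let _response = client.generate("gpt-5.2", 16384, make_messages()).await?;
    Ok(())
}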
@@ -147,19 +156,101 @@ fn parse_response(response_text: &str) -> QaResult { } } +enum QaClient { + Anthropic(AnthropicClient), + OpenAi(OpenAiClient), +} + +impl QaClient { + async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result> { + match self { + QaClient::Anthropic(client) => { + let messages = vec![anthropic::Message { + role: anthropic::Role::User, + content: vec![anthropic::RequestContent::Text { + text: prompt.to_string(), + cache_control: None, + }], + }]; + let response = client.generate(model, max_tokens, messages).await?; + Ok(response.map(|r| { + r.content + .iter() + .filter_map(|c| match c { + anthropic::ResponseContent::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("") + })) + } + QaClient::OpenAi(client) => { + let messages = vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt.to_string()), + }]; + let response = client.generate(model, max_tokens, messages).await?; + Ok(response.map(|r| { + r.choices + .into_iter() + .filter_map(|choice| match choice.message { + open_ai::RequestMessage::Assistant { content, .. } => { + content.map(|c| match c { + open_ai::MessageContent::Plain(text) => text, + open_ai::MessageContent::Multipart(parts) => parts + .into_iter() + .filter_map(|p| match p { + open_ai::MessagePart::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join(""), + }) + } + _ => None, + }) + .collect::>() + .join("") + })) + } + } + } + + async fn sync_batches(&self) -> Result<()> { + match self { + QaClient::Anthropic(client) => client.sync_batches().await, + QaClient::OpenAi(client) => client.sync_batches().await, + } + } +} + /// Run the QA evaluation on a set of examples. pub async fn run_qa( examples: &mut [Example], args: &QaArgs, output_path: Option<&PathBuf>, ) -> Result<()> { - let client = if args.no_batch { - AnthropicClient::plain()? - } else { - AnthropicClient::batch(&LLM_CACHE_DB)? + let model = model_for_backend(args.backend); + let client = match args.backend { + BatchProvider::Anthropic => { + if args.no_batch { + QaClient::Anthropic(AnthropicClient::plain()?) + } else { + QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?) + } + } + BatchProvider::Openai => { + if args.no_batch { + QaClient::OpenAi(OpenAiClient::plain()?) + } else { + QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?) 
+ } + } }; - eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch); + eprintln!( + "Using model: {}, backend: {:?}, batching: {}", + model, args.backend, !args.no_batch + ); // First pass: send requests (client handles caching internally) let mut prompts: Vec<(usize, String)> = Vec::new(); @@ -187,54 +278,16 @@ pub async fn run_qa( for (i, (idx, prompt)) in prompts.iter().enumerate() { eprint!("\rProcessing {}/{}", i + 1, prompts.len()); - let messages = vec![Message { - role: Role::User, - content: vec![RequestContent::Text { - text: prompt.clone(), - cache_control: None, - }], - }]; - - let response = client.generate(MODEL, 1024, messages).await?; - let result = response.map(|r| { - let text = r - .content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => Some(text.as_str()), - _ => None, - }) - .collect::>() - .join(""); - parse_response(&text) - }); + let response = client.generate(model, 1024, prompt).await?; + let result = response.map(|text| parse_response(&text)); results.push((*idx, result)); } eprintln!(); } else { // Queue all for batching for (idx, prompt) in &prompts { - let messages = vec![Message { - role: Role::User, - content: vec![RequestContent::Text { - text: prompt.clone(), - cache_control: None, - }], - }]; - - let response = client.generate(MODEL, 1024, messages).await?; - let result = response.map(|r| { - let text = r - .content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => Some(text.as_str()), - _ => None, - }) - .collect::>() - .join(""); - parse_response(&text) - }); + let response = client.generate(model, 1024, prompt).await?; + let result = response.map(|text| parse_response(&text)); results.push((*idx, result)); } @@ -251,27 +304,8 @@ pub async fn run_qa( let mut all_done = true; for (result_idx, (idx, prompt)) in prompts.iter().enumerate() { if results[result_idx].1.is_none() { - let messages = vec![Message { - role: Role::User, - content: vec![RequestContent::Text { - text: prompt.clone(), - cache_control: None, - }], - }]; - - let response = client.generate(MODEL, 1024, messages).await?; - if let Some(r) = response { - let text = r - .content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => { - Some(text.as_str()) - } - _ => None, - }) - .collect::>() - .join(""); + let response = client.generate(model, 1024, prompt).await?; + if let Some(text) = response { results[result_idx] = (*idx, Some(parse_response(&text))); } else { all_done = false; diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs index a27205e131b19bfca7f3cedce6a1f01028b863bb..2f8de97d8cdf0126314a03ad47c52f2815f41639 100644 --- a/crates/edit_prediction_cli/src/repair.rs +++ b/crates/edit_prediction_cli/src/repair.rs @@ -4,20 +4,18 @@ //! predictions that need improvement (based on reverts_edits or low confidence), //! and uses an LLM to generate improved predictions. +use crate::BatchProvider; use crate::PredictionProvider; use crate::anthropic_client::AnthropicClient; use crate::example::{Example, ExamplePrediction}; use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example}; +use crate::openai_client::OpenAiClient; use crate::paths::LLM_CACHE_DB; use crate::word_diff::unified_to_word_diff; -use anthropic::{Message, RequestContent, Role}; use anyhow::Result; use std::io::{BufWriter, Write}; use std::path::PathBuf; -/// Model to use for repair. 
-const MODEL: &str = "claude-sonnet-4-5"; - const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md"); /// Arguments for the repair command. @@ -34,6 +32,17 @@ pub struct RepairArgs { /// Confidence threshold: repair predictions with confidence <= this value (1-5) #[clap(long, default_value = "2")] pub confidence_threshold: u8, + + /// Which LLM provider to use (anthropic or openai) + #[clap(long, default_value = "anthropic")] + pub backend: BatchProvider, +} + +fn model_for_backend(backend: BatchProvider) -> &'static str { + match backend { + BatchProvider::Anthropic => "claude-sonnet-4-5", + BatchProvider::Openai => "gpt-5.2", + } } /// Build the repair prompt for an example that needs improvement. @@ -127,21 +136,100 @@ fn parse_repair_response(example: &Example, response_text: &str) -> Result Result> { + match self { + RepairClient::Anthropic(client) => { + let messages = vec![anthropic::Message { + role: anthropic::Role::User, + content: vec![anthropic::RequestContent::Text { + text: prompt.to_string(), + cache_control: None, + }], + }]; + let response = client.generate(model, max_tokens, messages).await?; + Ok(response.map(|r| { + r.content + .iter() + .filter_map(|c| match c { + anthropic::ResponseContent::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("") + })) + } + RepairClient::OpenAi(client) => { + let messages = vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt.to_string()), + }]; + let response = client.generate(model, max_tokens, messages).await?; + Ok(response.map(|r| { + r.choices + .into_iter() + .filter_map(|choice| match choice.message { + open_ai::RequestMessage::Assistant { content, .. } => { + content.map(|c| match c { + open_ai::MessageContent::Plain(text) => text, + open_ai::MessageContent::Multipart(parts) => parts + .into_iter() + .filter_map(|p| match p { + open_ai::MessagePart::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join(""), + }) + } + _ => None, + }) + .collect::>() + .join("") + })) + } + } + } + + async fn sync_batches(&self) -> Result<()> { + match self { + RepairClient::Anthropic(client) => client.sync_batches().await, + RepairClient::OpenAi(client) => client.sync_batches().await, + } + } +} + /// Run the repair process on a set of examples. pub async fn run_repair( examples: &mut [Example], args: &RepairArgs, output_path: Option<&PathBuf>, ) -> Result<()> { - let client = if args.no_batch { - AnthropicClient::plain()? - } else { - AnthropicClient::batch(&LLM_CACHE_DB)? + let model = model_for_backend(args.backend); + let client = match args.backend { + BatchProvider::Anthropic => { + if args.no_batch { + RepairClient::Anthropic(AnthropicClient::plain()?) + } else { + RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?) + } + } + BatchProvider::Openai => { + if args.no_batch { + RepairClient::OpenAi(OpenAiClient::plain()?) + } else { + RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?) 
+ } + } }; eprintln!( - "Using model: {}, batching: {}, confidence_threshold: {}", - MODEL, !args.no_batch, args.confidence_threshold + "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}", + model, args.backend, !args.no_batch, args.confidence_threshold ); // First pass: identify examples that need repair and build prompts @@ -185,51 +273,15 @@ pub async fn run_repair( for (i, (idx, prompt)) in repair_items.iter().enumerate() { eprint!("\rProcessing {}/{}", i + 1, repair_items.len()); - let messages = vec![Message { - role: Role::User, - content: vec![RequestContent::Text { - text: prompt.clone(), - cache_control: None, - }], - }]; - - let response = client.generate(MODEL, 16384, messages).await?; - let result = response.map(|r| { - r.content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => Some(text.as_str()), - _ => None, - }) - .collect::>() - .join("") - }); - results.push((*idx, result)); + let response = client.generate(model, 16384, prompt).await?; + results.push((*idx, response)); } eprintln!(); } else { // Queue all for batching for (idx, prompt) in &repair_items { - let messages = vec![Message { - role: Role::User, - content: vec![RequestContent::Text { - text: prompt.clone(), - cache_control: None, - }], - }]; - - let response = client.generate(MODEL, 16384, messages).await?; - let result = response.map(|r| { - r.content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => Some(text.as_str()), - _ => None, - }) - .collect::>() - .join("") - }); - results.push((*idx, result)); + let response = client.generate(model, 16384, prompt).await?; + results.push((*idx, response)); } // Sync batches (upload pending, download finished) @@ -245,27 +297,8 @@ pub async fn run_repair( let mut all_done = true; for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() { if results[result_idx].1.is_none() { - let messages = vec![Message { - role: Role::User, - content: vec![RequestContent::Text { - text: prompt.clone(), - cache_control: None, - }], - }]; - - let response = client.generate(MODEL, 16384, messages).await?; - if let Some(r) = response { - let text = r - .content - .iter() - .filter_map(|c| match c { - anthropic::ResponseContent::Text { text } => { - Some(text.as_str()) - } - _ => None, - }) - .collect::>() - .join(""); + let response = client.generate(model, 16384, prompt).await?; + if let Some(text) = response { results[result_idx] = (*idx, Some(text)); } else { all_done = false; diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 037ca14437cd13a6fc4bfe76dafb113c6a9f1482..3de3a4dc3fcb8c9519f4c67be7cead75401f6281 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -19,6 +19,7 @@ schemars = ["dep:schemars"] anyhow.workspace = true futures.workspace = true http_client.workspace = true +rand.workspace = true schemars = { workspace = true, optional = true } log.workspace = true serde.workspace = true diff --git a/crates/open_ai/src/batches.rs b/crates/open_ai/src/batches.rs new file mode 100644 index 0000000000000000000000000000000000000000..a93fef675ddcdc3667b4f8448cacc54bada8272e --- /dev/null +++ b/crates/open_ai/src/batches.rs @@ -0,0 +1,332 @@ +use anyhow::Result; +use futures::AsyncReadExt; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; + +use crate::{Request, RequestError, Response}; + +/// A single request within a batch +#[derive(Debug, Serialize, Deserialize)] +pub struct 
BatchRequestItem { + pub custom_id: String, + pub method: String, + pub url: String, + pub body: Request, +} + +impl BatchRequestItem { + pub fn new(custom_id: String, request: Request) -> Self { + Self { + custom_id, + method: "POST".to_string(), + url: "/v1/chat/completions".to_string(), + body: request, + } + } + + pub fn to_jsonl_line(&self) -> Result { + serde_json::to_string(self) + } +} + +/// Request to create a batch +#[derive(Debug, Serialize)] +pub struct CreateBatchRequest { + pub input_file_id: String, + pub endpoint: String, + pub completion_window: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +impl CreateBatchRequest { + pub fn new(input_file_id: String) -> Self { + Self { + input_file_id, + endpoint: "/v1/chat/completions".to_string(), + completion_window: "24h".to_string(), + metadata: None, + } + } +} + +/// Response from batch creation or retrieval +#[derive(Debug, Serialize, Deserialize)] +pub struct Batch { + pub id: String, + pub object: String, + pub endpoint: String, + pub input_file_id: String, + pub completion_window: String, + pub status: String, + pub output_file_id: Option, + pub error_file_id: Option, + pub created_at: u64, + #[serde(default)] + pub in_progress_at: Option, + #[serde(default)] + pub expires_at: Option, + #[serde(default)] + pub finalizing_at: Option, + #[serde(default)] + pub completed_at: Option, + #[serde(default)] + pub failed_at: Option, + #[serde(default)] + pub expired_at: Option, + #[serde(default)] + pub cancelling_at: Option, + #[serde(default)] + pub cancelled_at: Option, + #[serde(default)] + pub request_counts: Option, + #[serde(default)] + pub metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct BatchRequestCounts { + pub total: u64, + pub completed: u64, + pub failed: u64, +} + +/// Response from file upload +#[derive(Debug, Serialize, Deserialize)] +pub struct FileObject { + pub id: String, + pub object: String, + pub bytes: u64, + pub created_at: u64, + pub filename: String, + pub purpose: String, +} + +/// Individual result from batch output +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchOutputItem { + pub id: String, + pub custom_id: String, + pub response: Option, + pub error: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchResponseBody { + pub status_code: u16, + pub body: Response, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchError { + pub code: String, + pub message: String, +} + +/// Upload a JSONL file for batch processing +pub async fn upload_batch_file( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + filename: &str, + content: Vec, +) -> Result { + let uri = format!("{api_url}/files"); + + let boundary = format!("----WebKitFormBoundary{:x}", rand::random::()); + + let mut body = Vec::new(); + body.extend_from_slice(format!("--{boundary}\r\n").as_bytes()); + body.extend_from_slice(b"Content-Disposition: form-data; name=\"purpose\"\r\n\r\n"); + body.extend_from_slice(b"batch\r\n"); + body.extend_from_slice(format!("--{boundary}\r\n").as_bytes()); + body.extend_from_slice( + format!("Content-Disposition: form-data; name=\"file\"; filename=\"{filename}\"\r\n") + .as_bytes(), + ); + body.extend_from_slice(b"Content-Type: application/jsonl\r\n\r\n"); + body.extend_from_slice(&content); + body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes()); + + let request = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Authorization", format!("Bearer {}", 
api_key.trim())) + .header( + "Content-Type", + format!("multipart/form-data; boundary={boundary}"), + ) + .body(AsyncBody::from(body)) + .map_err(|e| RequestError::Other(e.into()))?; + + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into())) + } else { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + Err(RequestError::HttpResponseError { + provider: "openai".to_owned(), + status_code: response.status(), + body, + headers: response.headers().clone(), + }) + } +} + +/// Create a batch from an uploaded file +pub async fn create_batch( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: CreateBatchRequest, +) -> Result { + let uri = format!("{api_url}/batches"); + + let serialized = serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?; + + let request = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Authorization", format!("Bearer {}", api_key.trim())) + .header("Content-Type", "application/json") + .body(AsyncBody::from(serialized)) + .map_err(|e| RequestError::Other(e.into()))?; + + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into())) + } else { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + Err(RequestError::HttpResponseError { + provider: "openai".to_owned(), + status_code: response.status(), + body, + headers: response.headers().clone(), + }) + } +} + +/// Retrieve batch status +pub async fn retrieve_batch( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + batch_id: &str, +) -> Result { + let uri = format!("{api_url}/batches/{batch_id}"); + + let request = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + .header("Authorization", format!("Bearer {}", api_key.trim())) + .body(AsyncBody::default()) + .map_err(|e| RequestError::Other(e.into()))?; + + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into())) + } else { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + Err(RequestError::HttpResponseError { + provider: "openai".to_owned(), + status_code: response.status(), + body, + headers: response.headers().clone(), + }) + } +} + +/// Download file content (for batch results) +pub async fn download_file( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + file_id: &str, +) -> Result { + let uri = format!("{api_url}/files/{file_id}/content"); + + let request = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + .header("Authorization", format!("Bearer {}", api_key.trim())) + .body(AsyncBody::default()) + .map_err(|e| RequestError::Other(e.into()))?; + + let 
mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + Ok(body) + } else { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + Err(RequestError::HttpResponseError { + provider: "openai".to_owned(), + status_code: response.status(), + body, + headers: response.headers().clone(), + }) + } +} + +/// Parse batch output JSONL into individual results +pub fn parse_batch_output(content: &str) -> Result, serde_json::Error> { + content + .lines() + .filter(|line| !line.trim().is_empty()) + .map(|line| serde_json::from_str(line)) + .collect() +} diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index e05865e079cbf71add94b790b1659b09a8e8fa22..073217e777c39f374560c208923848ea88e11a6a 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,3 +1,4 @@ +pub mod batches; pub mod responses; use anyhow::{Context as _, Result, anyhow}; @@ -529,6 +530,52 @@ pub struct ResponseStreamEvent { pub usage: Option, } +pub async fn non_streaming_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result { + let uri = format!("{api_url}/chat/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key.trim())); + + let request = request_builder + .body(AsyncBody::from( + serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?, + )) + .map_err(|e| RequestError::Other(e.into()))?; + + let mut response = client.send(request).await?; + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into())) + } else { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + Err(RequestError::HttpResponseError { + provider: "openai".to_owned(), + status_code: response.status(), + body, + headers: response.headers().clone(), + }) + } +} + pub async fn stream_completion( client: &dyn HttpClient, provider_name: &str,
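Also not part of the patch: a sketch of driving the new open_ai::batches helpers directly, mirroring what BatchingOpenAiClient does internally. The model name, prompt, file name, and custom_id below are placeholders.

use std::sync::Arc;

use anyhow::Result;
use http_client::HttpClient;
use open_ai::{MessageContent, OPEN_AI_API_URL, Request, RequestMessage, batches};
use reqwest_client::ReqwestClient;

async fn run_one_off_batch(api_key: &str) -> Result<()> {
    let http: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());

    // Build a single-request JSONL payload.
    let request = Request {
        model: "gpt-5.2".into(),
        messages: vec![RequestMessage::User {
            content: MessageContent::Plain("<prompt>".into()),
        }],
        stream: false,
        max_completion_tokens: Some(1024),
        stop: Vec::new(),
        temperature: None,
        tool_choice: None,
        parallel_tool_calls: None,
        tools: Vec::new(),
        prompt_cache_key: None,
        reasoning_effort: None,
    };
    let item = batches::BatchRequestItem::new("req_hash_demo".into(), request);
    let jsonl = format!("{}\n", item.to_jsonl_line()?);

    // Upload the JSONL file and create the batch from it.
    let file = batches::upload_batch_file(
        http.as_ref(), OPEN_AI_API_URL, api_key, "demo.jsonl", jsonl.into_bytes(),
    )
    .await
    .map_err(|e| anyhow::anyhow!("{e:?}"))?;
    let batch = batches::create_batch(
        http.as_ref(), OPEN_AI_API_URL, api_key, batches::CreateBatchRequest::new(file.id),
    )
    .await
    .map_err(|e| anyhow::anyhow!("{e:?}"))?;

    // Later: check status, then download and parse the results.
    let status = batches::retrieve_batch(http.as_ref(), OPEN_AI_API_URL, api_key, &batch.id)
        .await
        .map_err(|e| anyhow::anyhow!("{e:?}"))?;
    if status.status == "completed" {
        if let Some(output_file_id) = status.output_file_id {
            let content =
                batches::download_file(http.as_ref(), OPEN_AI_API_URL, api_key, &output_file_id)
                    .await
                    .map_err(|e| anyhow::anyhow!("{e:?}"))?;
            for result in batches::parse_batch_output(&content)? {
                println!("{}: ok = {}", result.custom_id, result.response.is_some());
            }
        }
    }
    Ok(())
}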