Detailed changes
@@ -5338,6 +5338,7 @@ dependencies = [
"libc",
"log",
"node_runtime",
+ "open_ai",
"paths",
"pretty_assertions",
"project",
@@ -11140,6 +11141,7 @@ dependencies = [
"futures 0.3.31",
"http_client",
"log",
+ "rand 0.9.2",
"schemars",
"serde",
"serde_json",
@@ -37,6 +37,7 @@ languages = { workspace = true, features = ["load-grammars"] }
libc.workspace = true
log.workspace = true
node_runtime.workspace = true
+open_ai.workspace = true
paths.workspace = true
project.workspace = true
@@ -1,17 +1,27 @@
use anyhow::Result;
-use std::mem;
-use crate::example::Example;
+use crate::{PredictionProvider, example::Example};
pub async fn run_distill(example: &mut Example) -> Result<()> {
- let predictions = mem::take(&mut example.predictions)
+    let repair_prediction = example
+        .predictions
+        .iter()
+        .find(|p| p.provider == PredictionProvider::Repair);
+    let predictions = if let Some(repair_prediction) = repair_prediction {
+        vec![repair_prediction]
+ } else {
+ example.predictions.iter().collect()
+ };
+
+ let expected_patches = predictions
.into_iter()
- .filter_map(|p| p.actual_patch)
+ .filter_map(|p| p.actual_patch.clone())
.collect();
- example.spec.expected_patches = predictions;
+ example.spec.expected_patches = expected_patches;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
+ example.qa = Vec::new();
Ok(())
}
@@ -73,6 +73,7 @@ pub struct ExamplePromptInputs {
pub struct ExamplePrompt {
pub input: String,
pub expected_output: String,
+    pub rejected_output: Option<String>, // Dispreferred output used as the negative side of a DPO (Direct Preference Optimization) pair
pub provider: PredictionProvider,
}
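The new `rejected_output` field pairs each prompt with a dispreferred completion so that examples can later be exported as preference data. As a rough sketch (the record shape and the helper name below are assumptions, not part of this change), such a triple could be turned into a DPO-style training record like this:

use serde_json::json;

// Hypothetical export helper: the `prompt`/`chosen`/`rejected` field names are
// illustrative; the actual DPO training format is not defined in this change.
fn to_dpo_record(prompt: &ExamplePrompt) -> Option<serde_json::Value> {
    let rejected = prompt.rejected_output.as_ref()?;
    Some(json!({
        "prompt": prompt.input,
        "chosen": prompt.expected_output,
        "rejected": rejected,
    }))
}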
@@ -48,27 +48,31 @@ pub async fn run_format_prompt(
let snapshot = cx.background_spawn(snapshot_fut).await;
match args.provider {
- PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
+ PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
step_progress.set_substatus("formatting teacher prompt");
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
cursor_point,
&snapshot,
- edit_prediction::zeta2::max_editable_tokens(version),
+ edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()),
edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
);
let editable_range = editable_range.to_offset(&snapshot);
let context_range = context_range.to_offset(&snapshot);
let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
+ let expected_output = example
+ .spec
+ .expected_patches
+ .first()
+ .cloned()
+ .unwrap_or_default();
+ let rejected_output = example.spec.rejected_patch.clone();
+
example.prompt = Some(ExamplePrompt {
input: prompt,
- expected_output: example
- .spec
- .expected_patches
- .first()
- .cloned()
- .unwrap_or_default(),
+ expected_output,
+ rejected_output,
provider: args.provider,
});
}
@@ -107,9 +111,16 @@ pub async fn run_format_prompt(
.clone(),
version,
)?;
+ let rejected_output = example
+ .spec
+ .rejected_patch
+ .as_ref()
+ .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok());
+
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,
+ rejected_output,
provider: args.provider,
});
}
@@ -7,6 +7,7 @@ mod git;
mod headless;
mod load_project;
mod metrics;
+mod openai_client;
mod parse_output;
mod paths;
mod predict;
@@ -241,14 +242,51 @@ struct EvalArgs {
summary_json: Option<PathBuf>,
}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum TeacherBackend {
+ Sonnet45,
+ Gpt52,
+}
+
+impl std::fmt::Display for TeacherBackend {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
+ TeacherBackend::Gpt52 => write!(f, "gpt52"),
+ }
+ }
+}
+
+impl std::str::FromStr for TeacherBackend {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s.to_lowercase().as_str() {
+ "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
+ "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
+            // Legacy ZetaVersion-style identifier, accepted for backwards compatibility.
+            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
+ _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
+ }
+ }
+}
+
+impl TeacherBackend {
+ pub fn model_name(&self) -> &'static str {
+ match self {
+ TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
+ TeacherBackend::Gpt52 => "gpt-5.2",
+ }
+ }
+}
+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
Sweep,
Mercury,
Zeta1,
Zeta2(ZetaVersion),
- Teacher(ZetaVersion),
- TeacherNonBatching(ZetaVersion),
+ Teacher(TeacherBackend),
+ TeacherNonBatching(TeacherBackend),
Repair,
}
@@ -265,9 +303,9 @@ impl std::fmt::Display for PredictionProvider {
PredictionProvider::Mercury => write!(f, "mercury"),
PredictionProvider::Zeta1 => write!(f, "zeta1"),
PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
- PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
- PredictionProvider::TeacherNonBatching(version) => {
- write!(f, "teacher-non-batching:{version}")
+ PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
+ PredictionProvider::TeacherNonBatching(backend) => {
+ write!(f, "teacher-non-batching:{backend}")
}
PredictionProvider::Repair => write!(f, "repair"),
}
@@ -277,28 +315,38 @@ impl std::fmt::Display for PredictionProvider {
impl std::str::FromStr for PredictionProvider {
type Err = anyhow::Error;
- fn from_str(mut s: &str) -> Result<Self, Self::Err> {
- let mut version = ZetaVersion::default();
- if let Some((first, second)) = s.split_once(':') {
- version = ZetaVersion::parse(second)?;
- s = first;
- }
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
- let s_lower = s.to_lowercase();
- match s_lower.as_str() {
+ let provider_lower = provider.to_lowercase();
+ match provider_lower.as_str() {
"sweep" => Ok(PredictionProvider::Sweep),
"mercury" => Ok(PredictionProvider::Mercury),
"zeta1" => Ok(PredictionProvider::Zeta1),
- "zeta2" => Ok(PredictionProvider::Zeta2(version)),
- "teacher" => Ok(PredictionProvider::Teacher(version)),
+ "zeta2" => {
+ let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
+ Ok(PredictionProvider::Zeta2(version))
+ }
+ "teacher" => {
+ let backend = arg
+ .map(|a| a.parse())
+ .transpose()?
+ .unwrap_or(TeacherBackend::Sonnet45);
+ Ok(PredictionProvider::Teacher(backend))
+ }
"teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
- Ok(PredictionProvider::TeacherNonBatching(version))
+ let backend = arg
+ .map(|a| a.parse())
+ .transpose()?
+ .unwrap_or(TeacherBackend::Sonnet45);
+ Ok(PredictionProvider::TeacherNonBatching(backend))
}
"repair" => Ok(PredictionProvider::Repair),
_ => {
anyhow::bail!(
- "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching, repair\n\
+ "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
+ For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
Available zeta versions:\n{}",
ZetaVersion::options_as_string()
)
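The provider string is now `provider[:argument]`, where the argument is a `ZetaVersion` for `zeta2` and a `TeacherBackend` for the teacher providers. Illustrative assertions for the parsing rules above (assuming these types are in scope; not part of this change):

fn provider_parsing_examples() -> anyhow::Result<()> {
    // Bare `teacher` falls back to the Sonnet backend.
    assert_eq!(
        "teacher".parse::<PredictionProvider>()?,
        PredictionProvider::Teacher(TeacherBackend::Sonnet45)
    );
    // An explicit backend after the colon selects the model.
    assert_eq!(
        "teacher:gpt52".parse::<PredictionProvider>()?,
        PredictionProvider::Teacher(TeacherBackend::Gpt52)
    );
    assert_eq!(TeacherBackend::Gpt52.model_name(), "gpt-5.2");
    // `zeta2` still accepts an optional version argument.
    assert!("zeta2:ordered".parse::<PredictionProvider>().is_ok());
    Ok(())
}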
@@ -347,9 +395,18 @@ struct SynthesizeArgs {
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
- /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
+ /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
#[clap(long, required = true, num_args = 1..)]
batch_ids: Vec<String>,
+ /// Which provider's batches to import (anthropic or openai)
+ #[clap(long, default_value = "anthropic")]
+ provider: BatchProvider,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum BatchProvider {
+ Anthropic,
+ Openai,
}
impl EpArgs {
@@ -537,11 +594,23 @@ fn main() {
match &command {
Command::ImportBatch(import_args) => {
smol::block_on(async {
- let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
- .expect("Failed to create Anthropic client");
- if let Err(e) = client.import_batches(&import_args.batch_ids).await {
- eprintln!("Error importing batches: {:?}", e);
- std::process::exit(1);
+ match import_args.provider {
+ BatchProvider::Anthropic => {
+ let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
+ .expect("Failed to create Anthropic client");
+ if let Err(e) = client.import_batches(&import_args.batch_ids).await {
+ eprintln!("Error importing Anthropic batches: {:?}", e);
+ std::process::exit(1);
+ }
+ }
+ BatchProvider::Openai => {
+ let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
+ .expect("Failed to create OpenAI client");
+ if let Err(e) = client.import_batches(&import_args.batch_ids).await {
+ eprintln!("Error importing OpenAI batches: {:?}", e);
+ std::process::exit(1);
+ }
+ }
}
println!(
"Successfully imported {} batch(es)",
@@ -0,0 +1,664 @@
+use anyhow::Result;
+use http_client::HttpClient;
+use indoc::indoc;
+use open_ai::{
+ MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
+ Response as OpenAiResponse, batches, non_streaming_completion,
+};
+use reqwest_client::ReqwestClient;
+use sqlez::bindable::Bind;
+use sqlez::bindable::StaticColumnCount;
+use sqlez_macros::sql;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::path::Path;
+use std::sync::{Arc, Mutex};
+
+pub struct PlainOpenAiClient {
+ pub http_client: Arc<dyn HttpClient>,
+ pub api_key: String,
+}
+
+impl PlainOpenAiClient {
+ pub fn new() -> Result<Self> {
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+ let api_key = std::env::var("OPENAI_API_KEY")
+ .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
+ Ok(Self {
+ http_client,
+ api_key,
+ })
+ }
+
+ pub async fn generate(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: Vec<RequestMessage>,
+ ) -> Result<OpenAiResponse> {
+ let request = OpenAiRequest {
+ model: model.to_string(),
+ messages,
+ stream: false,
+ max_completion_tokens: Some(max_tokens),
+ stop: Vec::new(),
+ temperature: None,
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: Vec::new(),
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ let response = non_streaming_completion(
+ self.http_client.as_ref(),
+ OPEN_AI_API_URL,
+ &self.api_key,
+ request,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ Ok(response)
+ }
+}
+
+pub struct BatchingOpenAiClient {
+ connection: Mutex<sqlez::connection::Connection>,
+ http_client: Arc<dyn HttpClient>,
+ api_key: String,
+}
+
+struct CacheRow {
+ request_hash: String,
+ request: Option<String>,
+ response: Option<String>,
+ batch_id: Option<String>,
+}
+
+impl StaticColumnCount for CacheRow {
+ fn column_count() -> usize {
+ 4
+ }
+}
+
+impl Bind for CacheRow {
+ fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
+ let next_index = statement.bind(&self.request_hash, start_index)?;
+ let next_index = statement.bind(&self.request, next_index)?;
+ let next_index = statement.bind(&self.response, next_index)?;
+ let next_index = statement.bind(&self.batch_id, next_index)?;
+ Ok(next_index)
+ }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableRequest {
+ model: String,
+ max_tokens: u64,
+ messages: Vec<SerializableMessage>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableMessage {
+ role: String,
+ content: String,
+}
+
+impl BatchingOpenAiClient {
+ fn new(cache_path: &Path) -> Result<Self> {
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+ let api_key = std::env::var("OPENAI_API_KEY")
+ .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
+
+ let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
+ let mut statement = sqlez::statement::Statement::prepare(
+ &connection,
+ indoc! {"
+ CREATE TABLE IF NOT EXISTS openai_cache (
+ request_hash TEXT PRIMARY KEY,
+ request TEXT,
+ response TEXT,
+ batch_id TEXT
+ );
+ "},
+ )?;
+ statement.exec()?;
+ drop(statement);
+
+ Ok(Self {
+ connection: Mutex::new(connection),
+ http_client,
+ api_key,
+ })
+ }
+
+ pub fn lookup(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: &[RequestMessage],
+ ) -> Result<Option<OpenAiResponse>> {
+ let request_hash_str = Self::request_hash(model, max_tokens, messages);
+ let connection = self.connection.lock().unwrap();
+ let response: Vec<String> = connection.select_bound(
+ &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
+ )?(request_hash_str.as_str())?;
+ Ok(response
+ .into_iter()
+ .next()
+ .and_then(|text| serde_json::from_str(&text).ok()))
+ }
+
+ pub fn mark_for_batch(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: &[RequestMessage],
+ ) -> Result<()> {
+ let request_hash = Self::request_hash(model, max_tokens, messages);
+
+ let serializable_messages: Vec<SerializableMessage> = messages
+ .iter()
+ .map(|msg| SerializableMessage {
+ role: message_role_to_string(msg),
+ content: message_content_to_string(msg),
+ })
+ .collect();
+
+ let serializable_request = SerializableRequest {
+ model: model.to_string(),
+ max_tokens,
+ messages: serializable_messages,
+ };
+
+ let request = Some(serde_json::to_string(&serializable_request)?);
+ let cache_row = CacheRow {
+ request_hash,
+ request,
+ response: None,
+ batch_id: None,
+ };
+ let connection = self.connection.lock().unwrap();
+ connection.exec_bound::<CacheRow>(sql!(
+ INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
+ cache_row,
+ )
+ }
+
+ async fn generate(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: Vec<RequestMessage>,
+ ) -> Result<Option<OpenAiResponse>> {
+ let response = self.lookup(model, max_tokens, &messages)?;
+ if let Some(response) = response {
+ return Ok(Some(response));
+ }
+
+ self.mark_for_batch(model, max_tokens, &messages)?;
+
+ Ok(None)
+ }
+
+ async fn sync_batches(&self) -> Result<()> {
+ let _batch_ids = self.upload_pending_requests().await?;
+ self.download_finished_batches().await
+ }
+
+ pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
+ for batch_id in batch_ids {
+ log::info!("Importing OpenAI batch {}", batch_id);
+
+ let batch_status = batches::retrieve_batch(
+ self.http_client.as_ref(),
+ OPEN_AI_API_URL,
+ &self.api_key,
+ batch_id,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;
+
+ log::info!("Batch {} status: {}", batch_id, batch_status.status);
+
+ if batch_status.status != "completed" {
+ log::warn!(
+ "Batch {} is not completed (status: {}), skipping",
+ batch_id,
+ batch_status.status
+ );
+ continue;
+ }
+
+ let output_file_id = batch_status.output_file_id.ok_or_else(|| {
+ anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
+ })?;
+
+ let results_content = batches::download_file(
+ self.http_client.as_ref(),
+ OPEN_AI_API_URL,
+ &self.api_key,
+ &output_file_id,
+ )
+ .await
+ .map_err(|e| {
+ anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
+ })?;
+
+ let results = batches::parse_batch_output(&results_content)
+ .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
+
+ let mut updates: Vec<(String, String, String)> = Vec::new();
+ let mut success_count = 0;
+ let mut error_count = 0;
+
+ for result in results {
+ let request_hash = result
+ .custom_id
+ .strip_prefix("req_hash_")
+ .unwrap_or(&result.custom_id)
+ .to_string();
+
+ if let Some(response_body) = result.response {
+ if response_body.status_code == 200 {
+ let response_json = serde_json::to_string(&response_body.body)?;
+ updates.push((request_hash, response_json, batch_id.clone()));
+ success_count += 1;
+ } else {
+ log::error!(
+ "Batch request {} failed with status {}",
+ request_hash,
+ response_body.status_code
+ );
+ let error_json = serde_json::json!({
+ "error": {
+ "type": "http_error",
+ "status_code": response_body.status_code
+ }
+ })
+ .to_string();
+ updates.push((request_hash, error_json, batch_id.clone()));
+ error_count += 1;
+ }
+ } else if let Some(error) = result.error {
+ log::error!(
+ "Batch request {} failed: {}: {}",
+ request_hash,
+ error.code,
+ error.message
+ );
+ let error_json = serde_json::json!({
+ "error": {
+ "type": error.code,
+ "message": error.message
+ }
+ })
+ .to_string();
+ updates.push((request_hash, error_json, batch_id.clone()));
+ error_count += 1;
+ }
+ }
+
+ let connection = self.connection.lock().unwrap();
+ connection.with_savepoint("batch_import", || {
+ let q = sql!(
+ INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
+ VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
+ );
+ let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
+ for (request_hash, response_json, batch_id) in &updates {
+ exec((
+ request_hash.as_str(),
+ request_hash.as_str(),
+ response_json.as_str(),
+ batch_id.as_str(),
+ ))?;
+ }
+ Ok(())
+ })?;
+
+ log::info!(
+ "Imported batch {}: {} successful, {} errors",
+ batch_id,
+ success_count,
+ error_count
+ );
+ }
+
+ Ok(())
+ }
+
+ async fn download_finished_batches(&self) -> Result<()> {
+ let batch_ids: Vec<String> = {
+ let connection = self.connection.lock().unwrap();
+ let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
+ connection.select(q)?()?
+ };
+
+ for batch_id in &batch_ids {
+ let batch_status = batches::retrieve_batch(
+ self.http_client.as_ref(),
+ OPEN_AI_API_URL,
+ &self.api_key,
+ batch_id,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ log::info!("Batch {} status: {}", batch_id, batch_status.status);
+
+ if batch_status.status == "completed" {
+ let output_file_id = match batch_status.output_file_id {
+ Some(id) => id,
+ None => {
+ log::warn!("Batch {} completed but has no output file", batch_id);
+ continue;
+ }
+ };
+
+ let results_content = batches::download_file(
+ self.http_client.as_ref(),
+ OPEN_AI_API_URL,
+ &self.api_key,
+ &output_file_id,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ let results = batches::parse_batch_output(&results_content)
+ .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
+
+ let mut updates: Vec<(String, String)> = Vec::new();
+ let mut success_count = 0;
+
+ for result in results {
+ let request_hash = result
+ .custom_id
+ .strip_prefix("req_hash_")
+ .unwrap_or(&result.custom_id)
+ .to_string();
+
+ if let Some(response_body) = result.response {
+ if response_body.status_code == 200 {
+ let response_json = serde_json::to_string(&response_body.body)?;
+ updates.push((response_json, request_hash));
+ success_count += 1;
+ } else {
+ log::error!(
+ "Batch request {} failed with status {}",
+ request_hash,
+ response_body.status_code
+ );
+ let error_json = serde_json::json!({
+ "error": {
+ "type": "http_error",
+ "status_code": response_body.status_code
+ }
+ })
+ .to_string();
+ updates.push((error_json, request_hash));
+ }
+ } else if let Some(error) = result.error {
+ log::error!(
+ "Batch request {} failed: {}: {}",
+ request_hash,
+ error.code,
+ error.message
+ );
+ let error_json = serde_json::json!({
+ "error": {
+ "type": error.code,
+ "message": error.message
+ }
+ })
+ .to_string();
+ updates.push((error_json, request_hash));
+ }
+ }
+
+ let connection = self.connection.lock().unwrap();
+ connection.with_savepoint("batch_download", || {
+ let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
+ let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+ for (response_json, request_hash) in &updates {
+ exec((response_json.as_str(), request_hash.as_str()))?;
+ }
+ Ok(())
+ })?;
+ log::info!("Downloaded {} successful requests", success_count);
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn upload_pending_requests(&self) -> Result<Vec<String>> {
+ const BATCH_CHUNK_SIZE: i32 = 16_000;
+ let mut all_batch_ids = Vec::new();
+ let mut total_uploaded = 0;
+
+ loop {
+ let rows: Vec<(String, String)> = {
+ let connection = self.connection.lock().unwrap();
+ let q = sql!(
+ SELECT request_hash, request FROM openai_cache
+ WHERE batch_id IS NULL AND response IS NULL
+ LIMIT ?
+ );
+ connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
+ };
+
+ if rows.is_empty() {
+ break;
+ }
+
+ let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
+
+ let mut jsonl_content = String::new();
+ for (hash, request_str) in &rows {
+ let serializable_request: SerializableRequest =
+ serde_json::from_str(request_str).unwrap();
+
+ let messages: Vec<RequestMessage> = serializable_request
+ .messages
+ .into_iter()
+ .map(|msg| match msg.role.as_str() {
+ "user" => RequestMessage::User {
+ content: MessageContent::Plain(msg.content),
+ },
+ "assistant" => RequestMessage::Assistant {
+ content: Some(MessageContent::Plain(msg.content)),
+ tool_calls: Vec::new(),
+ },
+ "system" => RequestMessage::System {
+ content: MessageContent::Plain(msg.content),
+ },
+ _ => RequestMessage::User {
+ content: MessageContent::Plain(msg.content),
+ },
+ })
+ .collect();
+
+ let request = OpenAiRequest {
+ model: serializable_request.model,
+ messages,
+ stream: false,
+ max_completion_tokens: Some(serializable_request.max_tokens),
+ stop: Vec::new(),
+ temperature: None,
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: Vec::new(),
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ let custom_id = format!("req_hash_{}", hash);
+ let batch_item = batches::BatchRequestItem::new(custom_id, request);
+ let line = batch_item
+ .to_jsonl_line()
+ .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
+ jsonl_content.push_str(&line);
+ jsonl_content.push('\n');
+ }
+
+ let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
+ let file_obj = batches::upload_batch_file(
+ self.http_client.as_ref(),
+ OPEN_AI_API_URL,
+ &self.api_key,
+ &filename,
+ jsonl_content.into_bytes(),
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;
+
+ let batch = batches::create_batch(
+ self.http_client.as_ref(),
+ OPEN_AI_API_URL,
+ &self.api_key,
+ batches::CreateBatchRequest::new(file_obj.id),
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;
+
+ {
+ let connection = self.connection.lock().unwrap();
+ connection.with_savepoint("batch_upload", || {
+ let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
+ let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+ for hash in &request_hashes {
+ exec((batch.id.as_str(), hash.as_str()))?;
+ }
+ Ok(())
+ })?;
+ }
+
+ let batch_len = rows.len();
+ total_uploaded += batch_len;
+ log::info!(
+ "Uploaded batch {} with {} requests ({} total)",
+ batch.id,
+ batch_len,
+ total_uploaded
+ );
+
+ all_batch_ids.push(batch.id);
+ }
+
+ if !all_batch_ids.is_empty() {
+ log::info!(
+ "Finished uploading {} batches with {} total requests",
+ all_batch_ids.len(),
+ total_uploaded
+ );
+ }
+
+ Ok(all_batch_ids)
+ }
+
+ fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
+ let mut hasher = std::hash::DefaultHasher::new();
+ "openai".hash(&mut hasher);
+ model.hash(&mut hasher);
+ max_tokens.hash(&mut hasher);
+ for msg in messages {
+ message_content_to_string(msg).hash(&mut hasher);
+ }
+ let request_hash = hasher.finish();
+ format!("{request_hash:016x}")
+ }
+}
+
+fn message_role_to_string(msg: &RequestMessage) -> String {
+ match msg {
+ RequestMessage::User { .. } => "user".to_string(),
+ RequestMessage::Assistant { .. } => "assistant".to_string(),
+ RequestMessage::System { .. } => "system".to_string(),
+ RequestMessage::Tool { .. } => "tool".to_string(),
+ }
+}
+
+fn message_content_to_string(msg: &RequestMessage) -> String {
+ match msg {
+ RequestMessage::User { content } => content_to_string(content),
+ RequestMessage::Assistant { content, .. } => {
+ content.as_ref().map(content_to_string).unwrap_or_default()
+ }
+ RequestMessage::System { content } => content_to_string(content),
+ RequestMessage::Tool { content, .. } => content_to_string(content),
+ }
+}
+
+fn content_to_string(content: &MessageContent) -> String {
+ match content {
+ MessageContent::Plain(text) => text.clone(),
+ MessageContent::Multipart(parts) => parts
+ .iter()
+ .filter_map(|part| match part {
+ open_ai::MessagePart::Text { text } => Some(text.clone()),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n"),
+ }
+}
+
+pub enum OpenAiClient {
+ Plain(PlainOpenAiClient),
+ Batch(BatchingOpenAiClient),
+ #[allow(dead_code)]
+ Dummy,
+}
+
+impl OpenAiClient {
+ pub fn plain() -> Result<Self> {
+ Ok(Self::Plain(PlainOpenAiClient::new()?))
+ }
+
+ pub fn batch(cache_path: &Path) -> Result<Self> {
+ Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
+ }
+
+ #[allow(dead_code)]
+ pub fn dummy() -> Self {
+ Self::Dummy
+ }
+
+ pub async fn generate(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: Vec<RequestMessage>,
+ ) -> Result<Option<OpenAiResponse>> {
+ match self {
+ OpenAiClient::Plain(plain_client) => plain_client
+ .generate(model, max_tokens, messages)
+ .await
+ .map(Some),
+ OpenAiClient::Batch(batching_client) => {
+ batching_client.generate(model, max_tokens, messages).await
+ }
+ OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
+ }
+ }
+
+ pub async fn sync_batches(&self) -> Result<()> {
+ match self {
+ OpenAiClient::Plain(_) => Ok(()),
+ OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
+ OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
+ }
+ }
+
+ pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
+ match self {
+ OpenAiClient::Plain(_) => {
+ anyhow::bail!("Import batches is only supported with batching client")
+ }
+ OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
+ OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
+ }
+ }
+}
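The batching client mirrors the existing Anthropic flow: `generate` returns `None` on a cache miss and stashes the request in SQLite, `sync_batches` uploads pending rows and downloads finished batches, and `import_batches` pulls results for explicitly given batch IDs. A rough sketch of the two-phase usage, assuming a cold cache, `OPENAI_API_KEY` set, and the `OpenAiClient` above in scope (the model name and cache path are illustrative):

async fn batched_roundtrip(cache_db: &std::path::Path) -> anyhow::Result<()> {
    let client = OpenAiClient::batch(cache_db)?;
    let make_messages = || {
        vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain("Hello".to_string()),
        }]
    };

    // First pass: cache miss, so the request is stashed and `None` comes back.
    let first = client.generate("gpt-5.2", 1024, make_messages()).await?;
    assert!(first.is_none());

    // Upload pending requests and download any batches that have finished.
    client.sync_batches().await?;

    // On a later run, once the batch has completed and been downloaded,
    // the same request is answered from the SQLite cache.
    let _cached = client.generate("gpt-5.2", 1024, make_messages()).await?;
    Ok(())
}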
@@ -1,10 +1,11 @@
use crate::{
- FormatPromptArgs, PredictArgs, PredictionProvider,
+ FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction, ExamplePrompt},
format_prompt::{TeacherPrompt, run_format_prompt},
headless::EpAppState,
load_project::run_load_project,
+ openai_client::OpenAiClient,
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
progress::{ExampleProgress, InfoStyle, Step},
retrieve_context::run_context_retrieval,
@@ -20,9 +21,9 @@ use std::{
atomic::{AtomicUsize, Ordering::SeqCst},
},
};
-use zeta_prompt::ZetaVersion;
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
+static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
pub async fn run_prediction(
example: &mut Example,
@@ -53,7 +54,7 @@ pub async fn run_prediction(
run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
- if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
+ if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
provider
{
let _step_progress = example_progress.start(Step::Predict);
@@ -68,7 +69,7 @@ pub async fn run_prediction(
.await?;
let batched = matches!(provider, PredictionProvider::Teacher(..));
- return predict_anthropic(example, repetition_count, version, batched).await;
+ return predict_teacher(example, backend, batched).await;
}
run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -148,6 +149,7 @@ pub async fn run_prediction(
updated_example.prompt.get_or_insert(ExamplePrompt {
input: prompt,
expected_output: String::new(),
+ rejected_output: None,
provider,
});
}
@@ -255,13 +257,23 @@ pub async fn run_prediction(
Ok(())
}
+async fn predict_teacher(
+ example: &mut Example,
+ backend: TeacherBackend,
+ batched: bool,
+) -> anyhow::Result<()> {
+ match backend {
+ TeacherBackend::Sonnet45 => predict_anthropic(example, backend, batched).await,
+ TeacherBackend::Gpt52 => predict_openai(example, backend, batched).await,
+ }
+}
+
async fn predict_anthropic(
example: &mut Example,
- _repetition_count: usize,
- version: ZetaVersion,
+ backend: TeacherBackend,
batched: bool,
) -> anyhow::Result<()> {
- let llm_model_name = "claude-sonnet-4-5";
+ let llm_model_name = backend.model_name();
let max_tokens = 16384;
let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
let client = if batched {
@@ -300,16 +312,83 @@ async fn predict_anthropic(
.collect::<Vec<String>>()
.join("\n");
- let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;
+ let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch: Some(actual_patch),
actual_output,
error: None,
provider: if batched {
- PredictionProvider::Teacher(version)
+ PredictionProvider::Teacher(backend)
} else {
- PredictionProvider::TeacherNonBatching(version)
+ PredictionProvider::TeacherNonBatching(backend)
+ },
+ };
+
+ example.predictions.push(prediction);
+ Ok(())
+}
+
+async fn predict_openai(
+ example: &mut Example,
+ backend: TeacherBackend,
+ batched: bool,
+) -> anyhow::Result<()> {
+ let llm_model_name = backend.model_name();
+ let max_tokens = 16384;
+ let llm_client = OPENAI_CLIENT.get_or_init(|| {
+ let client = if batched {
+ OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
+ } else {
+ OpenAiClient::plain()
+ };
+ client.expect("Failed to create OpenAI client")
+ });
+
+ let prompt = example.prompt.as_ref().context("Prompt is required")?;
+
+ let messages = vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt.input.clone()),
+ }];
+
+ let Some(response) = llm_client
+ .generate(llm_model_name, max_tokens, messages)
+ .await?
+ else {
+ // Request stashed for batched processing
+ return Ok(());
+ };
+
+ let actual_output = response
+ .choices
+ .into_iter()
+ .filter_map(|choice| match choice.message {
+ open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
+ open_ai::MessageContent::Plain(text) => text,
+ open_ai::MessageContent::Multipart(parts) => parts
+ .into_iter()
+ .filter_map(|p| match p {
+ open_ai::MessagePart::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join(""),
+ }),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n");
+
+ let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+ let prediction = ExamplePrediction {
+ actual_patch: Some(actual_patch),
+ actual_output,
+ error: None,
+ provider: if batched {
+ PredictionProvider::Teacher(backend)
+ } else {
+ PredictionProvider::TeacherNonBatching(backend)
},
};
@@ -319,16 +398,28 @@ async fn predict_anthropic(
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
match provider {
- Some(PredictionProvider::Teacher(..)) => {
- let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
- AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
- .expect("Failed to create Anthropic client")
- });
- llm_client
- .sync_batches()
- .await
- .context("Failed to sync batches")?;
- }
+ Some(PredictionProvider::Teacher(backend)) => match backend {
+ TeacherBackend::Sonnet45 => {
+ let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+ AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+ .expect("Failed to create Anthropic client")
+ });
+ llm_client
+ .sync_batches()
+ .await
+ .context("Failed to sync Anthropic batches")?;
+ }
+ TeacherBackend::Gpt52 => {
+ let llm_client = OPENAI_CLIENT.get_or_init(|| {
+ OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
+ .expect("Failed to create OpenAI client")
+ });
+ llm_client
+ .sync_batches()
+ .await
+ .context("Failed to sync OpenAI batches")?;
+ }
+ },
_ => (),
};
Ok(())
@@ -1,22 +1,20 @@
//! Quality assessment of predictions using LLM-as-a-judge.
//!
-//! This module uses the Anthropic Batch API to evaluate prediction quality.
-//! Caching is handled by the underlying AnthropicClient.
+//! This module uses LLM Batch APIs to evaluate prediction quality.
+//! Caching is handled by the underlying client.
+use crate::BatchProvider;
use crate::anthropic_client::AnthropicClient;
use crate::example::Example;
use crate::format_prompt::extract_cursor_excerpt_from_example;
+use crate::openai_client::OpenAiClient;
use crate::paths::LLM_CACHE_DB;
use crate::word_diff::unified_to_word_diff;
-use anthropic::{Message, RequestContent, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::io::{BufWriter, Write};
use std::path::PathBuf;
-/// Model to use for QA evaluation.
-const MODEL: &str = "claude-sonnet-4-5";
-
const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
/// Arguments for the QA command.
@@ -29,6 +27,17 @@ pub struct QaArgs {
/// Wait for batch to complete (polls every 30s)
#[clap(long)]
pub wait: bool,
+
+ /// Which LLM provider to use (anthropic or openai)
+ #[clap(long, default_value = "anthropic")]
+ pub backend: BatchProvider,
+}
+
+fn model_for_backend(backend: BatchProvider) -> &'static str {
+ match backend {
+ BatchProvider::Anthropic => "claude-sonnet-4-5",
+ BatchProvider::Openai => "gpt-5.2",
+ }
}
/// Result of QA evaluation for a single prediction.
@@ -147,19 +156,101 @@ fn parse_response(response_text: &str) -> QaResult {
}
}
+enum QaClient {
+ Anthropic(AnthropicClient),
+ OpenAi(OpenAiClient),
+}
+
+impl QaClient {
+ async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
+ match self {
+ QaClient::Anthropic(client) => {
+ let messages = vec![anthropic::Message {
+ role: anthropic::Role::User,
+ content: vec![anthropic::RequestContent::Text {
+ text: prompt.to_string(),
+ cache_control: None,
+ }],
+ }];
+ let response = client.generate(model, max_tokens, messages).await?;
+ Ok(response.map(|r| {
+ r.content
+ .iter()
+ .filter_map(|c| match c {
+ anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
+ }))
+ }
+ QaClient::OpenAi(client) => {
+ let messages = vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt.to_string()),
+ }];
+ let response = client.generate(model, max_tokens, messages).await?;
+ Ok(response.map(|r| {
+ r.choices
+ .into_iter()
+ .filter_map(|choice| match choice.message {
+ open_ai::RequestMessage::Assistant { content, .. } => {
+ content.map(|c| match c {
+ open_ai::MessageContent::Plain(text) => text,
+ open_ai::MessageContent::Multipart(parts) => parts
+ .into_iter()
+ .filter_map(|p| match p {
+ open_ai::MessagePart::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join(""),
+ })
+ }
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
+ }))
+ }
+ }
+ }
+
+ async fn sync_batches(&self) -> Result<()> {
+ match self {
+ QaClient::Anthropic(client) => client.sync_batches().await,
+ QaClient::OpenAi(client) => client.sync_batches().await,
+ }
+ }
+}
+
/// Run the QA evaluation on a set of examples.
pub async fn run_qa(
examples: &mut [Example],
args: &QaArgs,
output_path: Option<&PathBuf>,
) -> Result<()> {
- let client = if args.no_batch {
- AnthropicClient::plain()?
- } else {
- AnthropicClient::batch(&LLM_CACHE_DB)?
+ let model = model_for_backend(args.backend);
+ let client = match args.backend {
+ BatchProvider::Anthropic => {
+ if args.no_batch {
+ QaClient::Anthropic(AnthropicClient::plain()?)
+ } else {
+ QaClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
+ }
+ }
+ BatchProvider::Openai => {
+ if args.no_batch {
+ QaClient::OpenAi(OpenAiClient::plain()?)
+ } else {
+ QaClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
+ }
+ }
};
- eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);
+ eprintln!(
+ "Using model: {}, backend: {:?}, batching: {}",
+ model, args.backend, !args.no_batch
+ );
// First pass: send requests (client handles caching internally)
let mut prompts: Vec<(usize, String)> = Vec::new();
@@ -187,54 +278,16 @@ pub async fn run_qa(
for (i, (idx, prompt)) in prompts.iter().enumerate() {
eprint!("\rProcessing {}/{}", i + 1, prompts.len());
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let response = client.generate(MODEL, 1024, messages).await?;
- let result = response.map(|r| {
- let text = r
- .content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => Some(text.as_str()),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("");
- parse_response(&text)
- });
+ let response = client.generate(model, 1024, prompt).await?;
+ let result = response.map(|text| parse_response(&text));
results.push((*idx, result));
}
eprintln!();
} else {
// Queue all for batching
for (idx, prompt) in &prompts {
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let response = client.generate(MODEL, 1024, messages).await?;
- let result = response.map(|r| {
- let text = r
- .content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => Some(text.as_str()),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("");
- parse_response(&text)
- });
+ let response = client.generate(model, 1024, prompt).await?;
+ let result = response.map(|text| parse_response(&text));
results.push((*idx, result));
}
@@ -251,27 +304,8 @@ pub async fn run_qa(
let mut all_done = true;
for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
if results[result_idx].1.is_none() {
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let response = client.generate(MODEL, 1024, messages).await?;
- if let Some(r) = response {
- let text = r
- .content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => {
- Some(text.as_str())
- }
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("");
+ let response = client.generate(model, 1024, prompt).await?;
+ if let Some(text) = response {
results[result_idx] = (*idx, Some(parse_response(&text)));
} else {
all_done = false;
@@ -4,20 +4,18 @@
//! predictions that need improvement (based on reverts_edits or low confidence),
//! and uses an LLM to generate improved predictions.
+use crate::BatchProvider;
use crate::PredictionProvider;
use crate::anthropic_client::AnthropicClient;
use crate::example::{Example, ExamplePrediction};
use crate::format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example};
+use crate::openai_client::OpenAiClient;
use crate::paths::LLM_CACHE_DB;
use crate::word_diff::unified_to_word_diff;
-use anthropic::{Message, RequestContent, Role};
use anyhow::Result;
use std::io::{BufWriter, Write};
use std::path::PathBuf;
-/// Model to use for repair.
-const MODEL: &str = "claude-sonnet-4-5";
-
const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");
/// Arguments for the repair command.
@@ -34,6 +32,17 @@ pub struct RepairArgs {
/// Confidence threshold: repair predictions with confidence <= this value (1-5)
#[clap(long, default_value = "2")]
pub confidence_threshold: u8,
+
+ /// Which LLM provider to use (anthropic or openai)
+ #[clap(long, default_value = "anthropic")]
+ pub backend: BatchProvider,
+}
+
+fn model_for_backend(backend: BatchProvider) -> &'static str {
+ match backend {
+ BatchProvider::Anthropic => "claude-sonnet-4-5",
+ BatchProvider::Openai => "gpt-5.2",
+ }
}
/// Build the repair prompt for an example that needs improvement.
@@ -127,21 +136,100 @@ fn parse_repair_response(example: &Example, response_text: &str) -> Result<Examp
})
}
+enum RepairClient {
+ Anthropic(AnthropicClient),
+ OpenAi(OpenAiClient),
+}
+
+impl RepairClient {
+ async fn generate(&self, model: &str, max_tokens: u64, prompt: &str) -> Result<Option<String>> {
+ match self {
+ RepairClient::Anthropic(client) => {
+ let messages = vec![anthropic::Message {
+ role: anthropic::Role::User,
+ content: vec![anthropic::RequestContent::Text {
+ text: prompt.to_string(),
+ cache_control: None,
+ }],
+ }];
+ let response = client.generate(model, max_tokens, messages).await?;
+ Ok(response.map(|r| {
+ r.content
+ .iter()
+ .filter_map(|c| match c {
+ anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
+ }))
+ }
+ RepairClient::OpenAi(client) => {
+ let messages = vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt.to_string()),
+ }];
+ let response = client.generate(model, max_tokens, messages).await?;
+ Ok(response.map(|r| {
+ r.choices
+ .into_iter()
+ .filter_map(|choice| match choice.message {
+ open_ai::RequestMessage::Assistant { content, .. } => {
+ content.map(|c| match c {
+ open_ai::MessageContent::Plain(text) => text,
+ open_ai::MessageContent::Multipart(parts) => parts
+ .into_iter()
+ .filter_map(|p| match p {
+ open_ai::MessagePart::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join(""),
+ })
+ }
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join("")
+ }))
+ }
+ }
+ }
+
+ async fn sync_batches(&self) -> Result<()> {
+ match self {
+ RepairClient::Anthropic(client) => client.sync_batches().await,
+ RepairClient::OpenAi(client) => client.sync_batches().await,
+ }
+ }
+}
+
/// Run the repair process on a set of examples.
pub async fn run_repair(
examples: &mut [Example],
args: &RepairArgs,
output_path: Option<&PathBuf>,
) -> Result<()> {
- let client = if args.no_batch {
- AnthropicClient::plain()?
- } else {
- AnthropicClient::batch(&LLM_CACHE_DB)?
+ let model = model_for_backend(args.backend);
+ let client = match args.backend {
+ BatchProvider::Anthropic => {
+ if args.no_batch {
+ RepairClient::Anthropic(AnthropicClient::plain()?)
+ } else {
+ RepairClient::Anthropic(AnthropicClient::batch(&LLM_CACHE_DB)?)
+ }
+ }
+ BatchProvider::Openai => {
+ if args.no_batch {
+ RepairClient::OpenAi(OpenAiClient::plain()?)
+ } else {
+ RepairClient::OpenAi(OpenAiClient::batch(&LLM_CACHE_DB)?)
+ }
+ }
};
eprintln!(
- "Using model: {}, batching: {}, confidence_threshold: {}",
- MODEL, !args.no_batch, args.confidence_threshold
+ "Using model: {}, backend: {:?}, batching: {}, confidence_threshold: {}",
+ model, args.backend, !args.no_batch, args.confidence_threshold
);
// First pass: identify examples that need repair and build prompts
@@ -185,51 +273,15 @@ pub async fn run_repair(
for (i, (idx, prompt)) in repair_items.iter().enumerate() {
eprint!("\rProcessing {}/{}", i + 1, repair_items.len());
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let response = client.generate(MODEL, 16384, messages).await?;
- let result = response.map(|r| {
- r.content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => Some(text.as_str()),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("")
- });
- results.push((*idx, result));
+ let response = client.generate(model, 16384, prompt).await?;
+ results.push((*idx, response));
}
eprintln!();
} else {
// Queue all for batching
for (idx, prompt) in &repair_items {
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let response = client.generate(MODEL, 16384, messages).await?;
- let result = response.map(|r| {
- r.content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => Some(text.as_str()),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("")
- });
- results.push((*idx, result));
+ let response = client.generate(model, 16384, prompt).await?;
+ results.push((*idx, response));
}
// Sync batches (upload pending, download finished)
@@ -245,27 +297,8 @@ pub async fn run_repair(
let mut all_done = true;
for (result_idx, (idx, prompt)) in repair_items.iter().enumerate() {
if results[result_idx].1.is_none() {
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let response = client.generate(MODEL, 16384, messages).await?;
- if let Some(r) = response {
- let text = r
- .content
- .iter()
- .filter_map(|c| match c {
- anthropic::ResponseContent::Text { text } => {
- Some(text.as_str())
- }
- _ => None,
- })
- .collect::<Vec<_>>()
- .join("");
+ let response = client.generate(model, 16384, prompt).await?;
+ if let Some(text) = response {
results[result_idx] = (*idx, Some(text));
} else {
all_done = false;
@@ -19,6 +19,7 @@ schemars = ["dep:schemars"]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
+rand.workspace = true
schemars = { workspace = true, optional = true }
log.workspace = true
serde.workspace = true
@@ -0,0 +1,332 @@
+use anyhow::Result;
+use futures::AsyncReadExt;
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+
+use crate::{Request, RequestError, Response};
+
+/// A single request within a batch
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchRequestItem {
+ pub custom_id: String,
+ pub method: String,
+ pub url: String,
+ pub body: Request,
+}
+
+impl BatchRequestItem {
+ pub fn new(custom_id: String, request: Request) -> Self {
+ Self {
+ custom_id,
+ method: "POST".to_string(),
+ url: "/v1/chat/completions".to_string(),
+ body: request,
+ }
+ }
+
+ pub fn to_jsonl_line(&self) -> Result<String, serde_json::Error> {
+ serde_json::to_string(self)
+ }
+}
+
+/// Request to create a batch
+#[derive(Debug, Serialize)]
+pub struct CreateBatchRequest {
+ pub input_file_id: String,
+ pub endpoint: String,
+ pub completion_window: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub metadata: Option<serde_json::Value>,
+}
+
+impl CreateBatchRequest {
+ pub fn new(input_file_id: String) -> Self {
+ Self {
+ input_file_id,
+ endpoint: "/v1/chat/completions".to_string(),
+ completion_window: "24h".to_string(),
+ metadata: None,
+ }
+ }
+}
+
+/// Response from batch creation or retrieval
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Batch {
+ pub id: String,
+ pub object: String,
+ pub endpoint: String,
+ pub input_file_id: String,
+ pub completion_window: String,
+ pub status: String,
+ pub output_file_id: Option<String>,
+ pub error_file_id: Option<String>,
+ pub created_at: u64,
+ #[serde(default)]
+ pub in_progress_at: Option<u64>,
+ #[serde(default)]
+ pub expires_at: Option<u64>,
+ #[serde(default)]
+ pub finalizing_at: Option<u64>,
+ #[serde(default)]
+ pub completed_at: Option<u64>,
+ #[serde(default)]
+ pub failed_at: Option<u64>,
+ #[serde(default)]
+ pub expired_at: Option<u64>,
+ #[serde(default)]
+ pub cancelling_at: Option<u64>,
+ #[serde(default)]
+ pub cancelled_at: Option<u64>,
+ #[serde(default)]
+ pub request_counts: Option<BatchRequestCounts>,
+ #[serde(default)]
+ pub metadata: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Default)]
+pub struct BatchRequestCounts {
+ pub total: u64,
+ pub completed: u64,
+ pub failed: u64,
+}
+
+/// Response from file upload
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FileObject {
+ pub id: String,
+ pub object: String,
+ pub bytes: u64,
+ pub created_at: u64,
+ pub filename: String,
+ pub purpose: String,
+}
+
+/// Individual result from batch output
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchOutputItem {
+ pub id: String,
+ pub custom_id: String,
+ pub response: Option<BatchResponseBody>,
+ pub error: Option<BatchError>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchResponseBody {
+ pub status_code: u16,
+ pub body: Response,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchError {
+ pub code: String,
+ pub message: String,
+}
+
+/// Upload a JSONL file for batch processing
+pub async fn upload_batch_file(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ filename: &str,
+ content: Vec<u8>,
+) -> Result<FileObject, RequestError> {
+ let uri = format!("{api_url}/files");
+
+ let boundary = format!("----WebKitFormBoundary{:x}", rand::random::<u64>());
+
+ let mut body = Vec::new();
+ body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
+ body.extend_from_slice(b"Content-Disposition: form-data; name=\"purpose\"\r\n\r\n");
+ body.extend_from_slice(b"batch\r\n");
+ body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
+ body.extend_from_slice(
+ format!("Content-Disposition: form-data; name=\"file\"; filename=\"{filename}\"\r\n")
+ .as_bytes(),
+ );
+ body.extend_from_slice(b"Content-Type: application/jsonl\r\n\r\n");
+ body.extend_from_slice(&content);
+ body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes());
+
+ let request = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Authorization", format!("Bearer {}", api_key.trim()))
+ .header(
+ "Content-Type",
+ format!("multipart/form-data; boundary={boundary}"),
+ )
+ .body(AsyncBody::from(body))
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+ } else {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ Err(RequestError::HttpResponseError {
+ provider: "openai".to_owned(),
+ status_code: response.status(),
+ body,
+ headers: response.headers().clone(),
+ })
+ }
+}
+
+/// Create a batch from an uploaded file
+pub async fn create_batch(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: CreateBatchRequest,
+) -> Result<Batch, RequestError> {
+ let uri = format!("{api_url}/batches");
+
+ let serialized = serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?;
+
+ let request = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Authorization", format!("Bearer {}", api_key.trim()))
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(serialized))
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+ } else {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ Err(RequestError::HttpResponseError {
+ provider: "openai".to_owned(),
+ status_code: response.status(),
+ body,
+ headers: response.headers().clone(),
+ })
+ }
+}
+
+/// Retrieve batch status
+pub async fn retrieve_batch(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ batch_id: &str,
+) -> Result<Batch, RequestError> {
+ let uri = format!("{api_url}/batches/{batch_id}");
+
+ let request = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Authorization", format!("Bearer {}", api_key.trim()))
+ .body(AsyncBody::default())
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+ } else {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ Err(RequestError::HttpResponseError {
+ provider: "openai".to_owned(),
+ status_code: response.status(),
+ body,
+ headers: response.headers().clone(),
+ })
+ }
+}
+
+/// Download file content (for batch results)
+pub async fn download_file(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ file_id: &str,
+) -> Result<String, RequestError> {
+ let uri = format!("{api_url}/files/{file_id}/content");
+
+ let request = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Authorization", format!("Bearer {}", api_key.trim()))
+ .body(AsyncBody::default())
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ Ok(body)
+ } else {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ Err(RequestError::HttpResponseError {
+ provider: "openai".to_owned(),
+ status_code: response.status(),
+ body,
+ headers: response.headers().clone(),
+ })
+ }
+}
+
+/// Parse batch output JSONL into individual results
+pub fn parse_batch_output(content: &str) -> Result<Vec<BatchOutputItem>, serde_json::Error> {
+ content
+ .lines()
+ .filter(|line| !line.trim().is_empty())
+ .map(|line| serde_json::from_str(line))
+ .collect()
+}
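Together these helpers cover the whole OpenAI Batch API lifecycle: upload a JSONL file, create the batch, poll its status, then download and parse the output. A condensed sketch of that sequence, assumed to live alongside the `batches` module inside the `open_ai` crate (IDs, file name, and the single status check are illustrative; not part of this change):

async fn batch_lifecycle(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    requests: Vec<(String, Request)>, // (custom_id, chat completion request)
) -> anyhow::Result<Vec<BatchOutputItem>> {
    // 1. Serialize every request into one JSONL payload.
    let mut jsonl = String::new();
    for (custom_id, request) in requests {
        jsonl.push_str(&BatchRequestItem::new(custom_id, request).to_jsonl_line()?);
        jsonl.push('\n');
    }

    // 2. Upload the file and create a batch referencing it.
    let file = upload_batch_file(client, api_url, api_key, "requests.jsonl", jsonl.into_bytes())
        .await
        .map_err(|e| anyhow::anyhow!("{e:?}"))?;
    let batch = create_batch(client, api_url, api_key, CreateBatchRequest::new(file.id))
        .await
        .map_err(|e| anyhow::anyhow!("{e:?}"))?;

    // 3. Later: check the status and, once completed, fetch and parse the results.
    let status = retrieve_batch(client, api_url, api_key, &batch.id)
        .await
        .map_err(|e| anyhow::anyhow!("{e:?}"))?;
    if status.status == "completed" {
        if let Some(output_file_id) = status.output_file_id {
            let content = download_file(client, api_url, api_key, &output_file_id)
                .await
                .map_err(|e| anyhow::anyhow!("{e:?}"))?;
            return Ok(parse_batch_output(&content)?);
        }
    }
    Ok(Vec::new())
}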
@@ -1,3 +1,4 @@
+pub mod batches;
pub mod responses;
use anyhow::{Context as _, Result, anyhow};
@@ -529,6 +530,52 @@ pub struct ResponseStreamEvent {
pub usage: Option<Usage>,
}
+pub async fn non_streaming_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+) -> Result<Response, RequestError> {
+ let uri = format!("{api_url}/chat/completions");
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key.trim()));
+
+ let request = request_builder
+ .body(AsyncBody::from(
+ serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
+ ))
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
+ } else {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(|e| RequestError::Other(e.into()))?;
+
+ Err(RequestError::HttpResponseError {
+ provider: "openai".to_owned(),
+ status_code: response.status(),
+ body,
+ headers: response.headers().clone(),
+ })
+ }
+}
+
pub async fn stream_completion(
client: &dyn HttpClient,
provider_name: &str,