diff --git a/crates/edit_prediction_cli/src/anthropic_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs
index 4ba826617c40e2e5f71530e5b25eeb42e1351bb9..e7723926860d80e70cb3f3a40a47d1a9a0a490c7 100644
--- a/crates/edit_prediction_cli/src/anthropic_client.rs
+++ b/crates/edit_prediction_cli/src/anthropic_client.rs
@@ -14,7 +14,7 @@ use sqlez_macros::sql;
 use std::hash::Hash;
 use std::hash::Hasher;
 use std::path::Path;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 pub struct PlainLlmClient {
     pub http_client: Arc<dyn HttpClient>,
@@ -134,7 +134,7 @@ impl PlainLlmClient {
 }
 
 pub struct BatchingLlmClient {
-    connection: sqlez::connection::Connection,
+    connection: Mutex<sqlez::connection::Connection>,
     http_client: Arc<dyn HttpClient>,
     api_key: String,
 }
@@ -197,7 +197,7 @@ impl BatchingLlmClient {
         drop(statement);
 
         Ok(Self {
-            connection,
+            connection: Mutex::new(connection),
            http_client,
             api_key,
         })
@@ -210,7 +210,8 @@ impl BatchingLlmClient {
         messages: &[Message],
     ) -> Result<Option<String>> {
         let request_hash_str = Self::request_hash(model, max_tokens, messages);
-        let response: Vec<String> = self.connection.select_bound(
+        let connection = self.connection.lock().unwrap();
+        let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
         )?(request_hash_str.as_str())?;
         Ok(response
@@ -246,7 +247,8 @@
             response: None,
             batch_id: None,
         };
-        self.connection.exec_bound(sql!(
+        let connection = self.connection.lock().unwrap();
+        connection.exec_bound::<CacheRow>(sql!(
             INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
             cache_row,
         )
@@ -275,10 +277,13 @@
     }
 
     async fn download_finished_batches(&self) -> Result<()> {
-        let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
-        let batch_ids: Vec<String> = self.connection.select(q)?()?;
+        let batch_ids: Vec<String> = {
+            let connection = self.connection.lock().unwrap();
+            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
+            connection.select(q)?()?
+        };
 
-        for batch_id in batch_ids {
+        for batch_id in &batch_ids {
             let batch_status = anthropic::batches::retrieve_batch(
                 self.http_client.as_ref(),
                 ANTHROPIC_API_URL,
@@ -360,9 +365,10 @@
             }
         }
 
-        self.connection.with_savepoint("batch_download", || {
+        let connection = self.connection.lock().unwrap();
+        connection.with_savepoint("batch_download", || {
             let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
-            let mut exec = self.connection.exec_bound::<(&str, &str)>(q)?;
+            let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
             for (response_json, request_hash) in &updates {
                 exec((response_json.as_str(), request_hash.as_str()))?;
             }
@@ -376,11 +382,13 @@
     }
 
     async fn upload_pending_requests(&self) -> Result<String> {
-        let q = sql!(
-            SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
-        );
-
-        let rows: Vec<(String, String)> = self.connection.select(q)?()?;
+        let rows: Vec<(String, String)> = {
+            let connection = self.connection.lock().unwrap();
+            let q = sql!(
+                SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
+            );
+            connection.select(q)?()?
+        };
 
         if rows.is_empty() {
             return Ok(String::new());
@@ -440,10 +448,13 @@
             .await
             .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 
-        let q = sql!(
-            UPDATE cache SET batch_id = ?
-            WHERE batch_id is NULL
-        );
-        self.connection.exec_bound(q)?(batch.id.as_str())?;
+        {
+            let connection = self.connection.lock().unwrap();
+            let q = sql!(
+                UPDATE cache SET batch_id = ? WHERE batch_id is NULL
+            );
+            connection.exec_bound(q)?(batch.id.as_str())?;
+        }
+
         log::info!("Uploaded batch with {} requests", batch_len);
 
diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs
index 04a58ee2e7b66f2ce40db626088baec898262405..7a53bea32f68245fe45fe3a9aea16aa240b6655d 100644
--- a/crates/edit_prediction_cli/src/predict.rs
+++ b/crates/edit_prediction_cli/src/predict.rs
@@ -21,6 +21,8 @@ use std::{
     },
 };
 
+static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
+
 pub async fn run_prediction(
     example: &mut Example,
     provider: Option<PredictionProvider>,
@@ -233,12 +235,14 @@
 ) -> anyhow::Result<()> {
     let llm_model_name = "claude-sonnet-4-5";
     let max_tokens = 16384;
-    let llm_client = if batched {
-        AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
-    } else {
-        AnthropicClient::plain()
-    };
-    let llm_client = llm_client.context("Failed to create LLM client")?;
+    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+        let client = if batched {
+            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+        } else {
+            AnthropicClient::plain()
+        };
+        client.expect("Failed to create Anthropic client")
+    });
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
@@ -283,9 +287,10 @@
 pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
     match provider {
         PredictionProvider::Teacher => {
-            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
-            let llm_client =
-                AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
+            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+                    .expect("Failed to create Anthropic client")
+            });
             llm_client
                 .sync_batches()
                 .await
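The two files combine into one pattern: the sqlez connection is wrapped in a `Mutex` so `BatchingLlmClient` can be shared by reference across tasks, and a `OnceLock` static memoizes a single client per process so `predict_anthropic` and `sync_batches` reuse one cache connection instead of constructing a new client each call. A minimal sketch of that pattern follows; the `Connection`, `BatchingClient`, and `client()` names are simplified stand-ins for illustration, not the real sqlez/Anthropic APIs.

```rust
use std::sync::{Mutex, OnceLock};

// Hypothetical stand-in for sqlez::connection::Connection: a handle that
// cannot be shared across threads directly.
struct Connection {
    cached: Vec<String>,
}

// Wrapping the handle in a Mutex gives the owning client interior mutability
// through &self, so it can live in a static shared by many callers.
struct BatchingClient {
    connection: Mutex<Connection>,
}

impl BatchingClient {
    fn new() -> Self {
        Self {
            connection: Mutex::new(Connection { cached: Vec::new() }),
        }
    }

    fn cache_request(&self, request_hash: &str) {
        // Keep the guard's scope tight; in async code the lock should be
        // released before any .await, which is why the diff scopes several
        // of its lock().unwrap() calls inside `{ ... }` blocks.
        let mut connection = self.connection.lock().unwrap();
        connection.cached.push(request_hash.to_owned());
    }
}

// Process-wide memoization, mirroring ANTHROPIC_CLIENT in predict.rs:
// the first caller constructs the client, every later caller reuses it.
static CLIENT: OnceLock<BatchingClient> = OnceLock::new();

fn client() -> &'static BatchingClient {
    CLIENT.get_or_init(BatchingClient::new)
}

fn main() {
    client().cache_request("abc123");
    client().cache_request("def456");
    assert_eq!(client().connection.lock().unwrap().cached.len(), 2);
}
```

One trade-off worth noting in the diff itself: `get_or_init` evaluates its closure only once, so the `batched` flag is honored only for the first call to `predict_anthropic`; every later call reuses whichever client mode the process started with. Similarly, moving from `.context(...)?` to `.expect(...)` turns a recoverable construction error into a panic, which is a reasonable simplification for a CLI but changes failure behavior.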