@@ -14,7 +14,7 @@ use sqlez_macros::sql;
use std::hash::Hash;
use std::hash::Hasher;
use std::path::Path;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
pub struct PlainLlmClient {
pub http_client: Arc<dyn HttpClient>,
@@ -134,7 +134,8 @@ impl BatchingLlmClient {
}
pub struct BatchingLlmClient {
- connection: sqlez::connection::Connection,
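+ // Mutex-wrapped so a shared client (e.g. stored in a static) can serialize access to the SQLite connection.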
+ connection: Mutex<sqlez::connection::Connection>,
http_client: Arc<dyn HttpClient>,
api_key: String,
}
@@ -197,7 +198,7 @@ impl BatchingLlmClient {
drop(statement);
Ok(Self {
- connection,
+ connection: Mutex::new(connection),
http_client,
api_key,
})
@@ -210,7 +211,9 @@ impl BatchingLlmClient {
messages: &[Message],
) -> Result<Option<AnthropicResponse>> {
let request_hash_str = Self::request_hash(model, max_tokens, messages);
- let response: Vec<String> = self.connection.select_bound(
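+ // Lock the connection only for this synchronous lookup; no await points while the guard is held.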
+ let connection = self.connection.lock().unwrap();
+ let response: Vec<String> = connection.select_bound(
&sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
)?(request_hash_str.as_str())?;
Ok(response
@@ -246,7 +249,9 @@ impl BatchingLlmClient {
response: None,
batch_id: None,
};
- self.connection.exec_bound(sql!(
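+ // Serialize the cache INSERT through the connection lock.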
+ let connection = self.connection.lock().unwrap();
+ connection.exec_bound::<CacheRow>(sql!(
INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
cache_row,
)
@@ -275,10 +280,14 @@ impl BatchingLlmClient {
}
async fn download_finished_batches(&self) -> Result<()> {
- let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
- let batch_ids: Vec<String> = self.connection.select(q)?()?;
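+ // Scope the lock so the guard is dropped before the batch-status requests below await.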
+ let batch_ids: Vec<String> = {
+ let connection = self.connection.lock().unwrap();
+ let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
+ connection.select(q)?()?
+ };
- for batch_id in batch_ids {
+ for batch_id in &batch_ids {
let batch_status = anthropic::batches::retrieve_batch(
self.http_client.as_ref(),
ANTHROPIC_API_URL,
@@ -360,9 +369,11 @@ impl BatchingLlmClient {
}
}
- self.connection.with_savepoint("batch_download", || {
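+ // Hold the lock for the whole savepoint so the row updates for this batch commit together.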
+ let connection = self.connection.lock().unwrap();
+ connection.with_savepoint("batch_download", || {
let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
- let mut exec = self.connection.exec_bound::<(&str, &str)>(q)?;
+ let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
for (response_json, request_hash) in &updates {
exec((response_json.as_str(), request_hash.as_str()))?;
}
@@ -376,11 +387,14 @@ impl BatchingLlmClient {
}
async fn upload_pending_requests(&self) -> Result<String> {
- let q = sql!(
- SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
- );
-
- let rows: Vec<(String, String)> = self.connection.select(q)?()?;
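+ // Snapshot the pending rows in a short lock scope; the guard is released before the upload awaits.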
+ let rows: Vec<(String, String)> = {
+ let connection = self.connection.lock().unwrap();
+ let q = sql!(
+ SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
+ );
+ connection.select(q)?()?
+ };
if rows.is_empty() {
return Ok(String::new());
@@ -440,10 +454,14 @@ impl BatchingLlmClient {
.await
.map_err(|e| anyhow::anyhow!("{:?}", e))?;
- let q = sql!(
- UPDATE cache SET batch_id = ? WHERE batch_id is NULL
- );
- self.connection.exec_bound(q)?(batch.id.as_str())?;
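+ // Tag every not-yet-batched row with the new batch id; the block keeps the lock scope short.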
+ {
+ let connection = self.connection.lock().unwrap();
+ let q = sql!(
+ UPDATE cache SET batch_id = ? WHERE batch_id IS NULL
+ );
+ connection.exec_bound(q)?(batch.id.as_str())?;
+ }
log::info!("Uploaded batch with {} requests", batch_len);
@@ -21,6 +21,9 @@ use std::{
},
};
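+// Process-wide Anthropic client, lazily initialized on first use.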
+static ANTHROPIC_CLIENT: std::sync::OnceLock<AnthropicClient> = std::sync::OnceLock::new();
+
pub async fn run_prediction(
example: &mut Example,
provider: Option<PredictionProvider>,
@@ -233,12 +236,15 @@ async fn predict_anthropic(
) -> anyhow::Result<()> {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
- let llm_client = if batched {
- AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
- } else {
- AnthropicClient::plain()
- };
- let llm_client = llm_client.context("Failed to create LLM client")?;
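+ // Initialized once per process; on later calls the batched flag has no effect and the existing client is reused.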
+ let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+ let client = if batched {
+ AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+ } else {
+ AnthropicClient::plain()
+ };
+ client.expect("Failed to create Anthropic client")
+ });
let prompt = example.prompt.as_ref().context("Prompt is required")?;
@@ -283,9 +289,11 @@
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
match provider {
PredictionProvider::Teacher => {
- let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
- let llm_client =
- AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
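+ // Reuses the process-wide client if it was already initialized.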
+ let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+ AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+ .expect("Failed to create Anthropic client")
+ });
llm_client
.sync_batches()
.await