ep: Cache Anthropic client (#46406)

Created by Max Brunsfeld and Ben Kunkle

This makes running `predict` with the teacher model much faster when
there are many examples.
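
For context: the speedup comes from memoizing the Anthropic client in a process-wide `OnceLock` static, so the client (and the SQLite cache connection it opens, presumably the expensive part) is constructed once per `predict` run instead of once per example. A minimal sketch of the pattern, with a hypothetical `Client` type standing in for the real `AnthropicClient`:

```rust
use std::sync::OnceLock;

// Hypothetical stand-in for AnthropicClient; the real type also opens
// a SQLite-backed cache database when constructed.
struct Client {
    api_key: String,
}

impl Client {
    fn new() -> Result<Self, std::env::VarError> {
        Ok(Self {
            api_key: std::env::var("ANTHROPIC_API_KEY")?,
        })
    }
}

// Initialized on first use, then shared by every subsequent call.
static CLIENT: OnceLock<Client> = OnceLock::new();

fn client() -> &'static Client {
    CLIENT.get_or_init(|| Client::new().expect("failed to create client"))
}
```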

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

crates/edit_prediction_cli/src/anthropic_client.rs | 49 +++++++++------
crates/edit_prediction_cli/src/predict.rs          | 23 ++++--
2 files changed, 44 insertions(+), 28 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/anthropic_client.rs

@@ -14,7 +14,7 @@ use sqlez_macros::sql;
 use std::hash::Hash;
 use std::hash::Hasher;
 use std::path::Path;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 pub struct PlainLlmClient {
     pub http_client: Arc<dyn HttpClient>,
@@ -134,7 +134,7 @@ impl PlainLlmClient {
 }
 
 pub struct BatchingLlmClient {
-    connection: sqlez::connection::Connection,
+    connection: Mutex<sqlez::connection::Connection>,
     http_client: Arc<dyn HttpClient>,
     api_key: String,
 }
@@ -197,7 +197,7 @@ impl BatchingLlmClient {
         drop(statement);
 
         Ok(Self {
-            connection,
+            connection: Mutex::new(connection),
             http_client,
             api_key,
         })
@@ -210,7 +210,8 @@ impl BatchingLlmClient {
         messages: &[Message],
     ) -> Result<Option<AnthropicResponse>> {
         let request_hash_str = Self::request_hash(model, max_tokens, messages);
-        let response: Vec<String> = self.connection.select_bound(
+        let connection = self.connection.lock().unwrap();
+        let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
         )?(request_hash_str.as_str())?;
         Ok(response
@@ -246,7 +247,8 @@ impl BatchingLlmClient {
             response: None,
             batch_id: None,
         };
-        self.connection.exec_bound(sql!(
+        let connection = self.connection.lock().unwrap();
+        connection.exec_bound::<CacheRow>(sql!(
             INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
             cache_row,
         )
@@ -275,10 +277,13 @@ impl BatchingLlmClient {
     }
 
     async fn download_finished_batches(&self) -> Result<()> {
-        let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
-        let batch_ids: Vec<String> = self.connection.select(q)?()?;
+        let batch_ids: Vec<String> = {
+            let connection = self.connection.lock().unwrap();
+            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
+            connection.select(q)?()?
+        };
 
-        for batch_id in batch_ids {
+        for batch_id in &batch_ids {
             let batch_status = anthropic::batches::retrieve_batch(
                 self.http_client.as_ref(),
                 ANTHROPIC_API_URL,
@@ -360,9 +365,10 @@ impl BatchingLlmClient {
                     }
                 }
 
-                self.connection.with_savepoint("batch_download", || {
+                let connection = self.connection.lock().unwrap();
+                connection.with_savepoint("batch_download", || {
                     let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
-                    let mut exec = self.connection.exec_bound::<(&str, &str)>(q)?;
+                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                     for (response_json, request_hash) in &updates {
                         exec((response_json.as_str(), request_hash.as_str()))?;
                     }
@@ -376,11 +382,13 @@ impl BatchingLlmClient {
     }
 
     async fn upload_pending_requests(&self) -> Result<String> {
-        let q = sql!(
-        SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
-        );
-
-        let rows: Vec<(String, String)> = self.connection.select(q)?()?;
+        let rows: Vec<(String, String)> = {
+            let connection = self.connection.lock().unwrap();
+            let q = sql!(
+            SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
+            );
+            connection.select(q)?()?
+        };
 
         if rows.is_empty() {
             return Ok(String::new());
@@ -440,10 +448,13 @@ impl BatchingLlmClient {
         .await
         .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 
-        let q = sql!(
-            UPDATE cache SET batch_id = ? WHERE batch_id is NULL
-        );
-        self.connection.exec_bound(q)?(batch.id.as_str())?;
+        {
+            let connection = self.connection.lock().unwrap();
+            let q = sql!(
+                UPDATE cache SET batch_id = ? WHERE batch_id is NULL
+            );
+            connection.exec_bound(q)?(batch.id.as_str())?;
+        }
 
         log::info!("Uploaded batch with {} requests", batch_len);
 

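A note on the `Mutex<sqlez::connection::Connection>` change above: once the client lives in a `static` and is shared, the connection has to be usable from multiple threads, and each `lock()` in the diff is deliberately scoped to a block that ends before the next `.await`, because a `std::sync::MutexGuard` is not `Send` and cannot be held across an await point in a future that must be `Send`. A minimal sketch of that scoping pattern, using a hypothetical `Db` type in place of the real connection:

```rust
use std::sync::Mutex;

// Hypothetical stand-in for the SQLite connection.
struct Db;

impl Db {
    fn pending_batch_ids(&self) -> Vec<String> {
        Vec::new()
    }
}

struct BatchClient {
    db: Mutex<Db>,
}

impl BatchClient {
    async fn sync(&self) {
        // Take the lock inside a block so the guard is dropped before
        // the first .await; std::sync::MutexGuard is not Send, so
        // holding it across an await would make this future non-Send.
        let ids: Vec<String> = {
            let db = self.db.lock().unwrap();
            db.pending_batch_ids()
        };

        for id in &ids {
            download_batch(id).await;
        }
    }
}

async fn download_batch(_id: &str) {
    // The network request would go here.
}
```
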
crates/edit_prediction_cli/src/predict.rs

@@ -21,6 +21,8 @@ use std::{
     },
 };
 
+static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
+
 pub async fn run_prediction(
     example: &mut Example,
     provider: Option<PredictionProvider>,
@@ -233,12 +235,14 @@ async fn predict_anthropic(
 ) -> anyhow::Result<()> {
     let llm_model_name = "claude-sonnet-4-5";
     let max_tokens = 16384;
-    let llm_client = if batched {
-        AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
-    } else {
-        AnthropicClient::plain()
-    };
-    let llm_client = llm_client.context("Failed to create LLM client")?;
+    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+        let client = if batched {
+            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+        } else {
+            AnthropicClient::plain()
+        };
+        client.expect("Failed to create Anthropic client")
+    });
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
@@ -283,9 +287,10 @@ async fn predict_anthropic(
 pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
     match provider {
         PredictionProvider::Teacher => {
-            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
-            let llm_client =
-                AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
+            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+                    .expect("Failed to create Anthropic client")
+            });
             llm_client
                 .sync_batches()
                 .await
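
One property of `OnceLock::get_or_init` worth keeping in mind when reading the two call sites above: the initializer closure runs at most once per process, so the first caller decides whether the cached client is the batching or the plain variant, and later callers get that same instance regardless of their own `batched` flag. A tiny illustration of the first-call-wins behavior:

```rust
use std::sync::OnceLock;

static MODE: OnceLock<&'static str> = OnceLock::new();

fn main() {
    // Only the first initializer runs; the second closure is ignored.
    let first = MODE.get_or_init(|| "batched");
    let second = MODE.get_or_init(|| "plain");
    assert_eq!(*first, "batched");
    assert_eq!(*second, "batched");
}
```

That is presumably fine for this CLI, where a single invocation uses one provider mode throughout.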