ep_cli: Chunk teacher prediction requests to Anthropic batch API (#47318)

Ben Kunkle created


We're hitting request size limits on the new larger dataset, so we need to split pending requests into chunks before uploading. A 16k chunk size was chosen somewhat arbitrarily, but it seems to be a good tradeoff: it keeps the number of batches low while staying within the per-request size limit.

Release Notes:

- N/A (internal `ep_cli` tooling change; no user-facing impact)

Change summary

crates/edit_prediction_cli/src/anthropic_client.rs | 174 +++++++++------
1 file changed, 102 insertions(+), 72 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/anthropic_client.rs 🔗

@@ -270,9 +270,9 @@ impl BatchingLlmClient {
         Ok(None)
     }
 
-    /// Uploads pending requests as a new batch; downloads finished batches if any.
+    /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
     async fn sync_batches(&self) -> Result<()> {
-        self.upload_pending_requests().await?;
+        let _batch_ids = self.upload_pending_requests().await?;
         self.download_finished_batches().await
     }
 
@@ -381,84 +381,114 @@ impl BatchingLlmClient {
         Ok(())
     }
 
-    async fn upload_pending_requests(&self) -> Result<String> {
-        let rows: Vec<(String, String)> = {
-            let connection = self.connection.lock().unwrap();
-            let q = sql!(
-            SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
-            );
-            connection.select(q)?()?
-        };
+    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
+        const BATCH_CHUNK_SIZE: i32 = 16_000;
+        let mut all_batch_ids = Vec::new();
+        let mut total_uploaded = 0;
 
-        if rows.is_empty() {
-            return Ok(String::new());
-        }
+        loop {
+            let rows: Vec<(String, String)> = {
+                let connection = self.connection.lock().unwrap();
+                let q = sql!(
+                    SELECT request_hash, request FROM cache
+                    WHERE batch_id IS NULL AND response IS NULL
+                    LIMIT ?
+                );
+                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
+            };
 
-        let batch_requests = rows
-            .iter()
-            .map(|(hash, request_str)| {
-                let serializable_request: SerializableRequest =
-                    serde_json::from_str(&request_str).unwrap();
-
-                let messages: Vec<Message> = serializable_request
-                    .messages
-                    .into_iter()
-                    .map(|msg| Message {
-                        role: match msg.role.as_str() {
-                            "user" => Role::User,
-                            "assistant" => Role::Assistant,
-                            _ => Role::User,
-                        },
-                        content: vec![RequestContent::Text {
-                            text: msg.content,
-                            cache_control: None,
-                        }],
-                    })
-                    .collect();
-
-                let params = AnthropicRequest {
-                    model: serializable_request.model,
-                    max_tokens: serializable_request.max_tokens,
-                    messages,
-                    tools: Vec::new(),
-                    thinking: None,
-                    tool_choice: None,
-                    system: None,
-                    metadata: None,
-                    stop_sequences: Vec::new(),
-                    temperature: None,
-                    top_k: None,
-                    top_p: None,
-                };
-
-                let custom_id = format!("req_hash_{}", hash);
-                anthropic::batches::BatchRequest { custom_id, params }
-            })
-            .collect::<Vec<_>>();
+            if rows.is_empty() {
+                break;
+            }
 
-        let batch_len = batch_requests.len();
-        let batch = anthropic::batches::create_batch(
-            self.http_client.as_ref(),
-            ANTHROPIC_API_URL,
-            &self.api_key,
-            anthropic::batches::CreateBatchRequest {
-                requests: batch_requests,
-            },
-        )
-        .await
-        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
+
+            let batch_requests = rows
+                .iter()
+                .map(|(hash, request_str)| {
+                    let serializable_request: SerializableRequest =
+                        serde_json::from_str(&request_str).unwrap();
+
+                    let messages: Vec<Message> = serializable_request
+                        .messages
+                        .into_iter()
+                        .map(|msg| Message {
+                            role: match msg.role.as_str() {
+                                "user" => Role::User,
+                                "assistant" => Role::Assistant,
+                                _ => Role::User,
+                            },
+                            content: vec![RequestContent::Text {
+                                text: msg.content,
+                                cache_control: None,
+                            }],
+                        })
+                        .collect();
+
+                    let params = AnthropicRequest {
+                        model: serializable_request.model,
+                        max_tokens: serializable_request.max_tokens,
+                        messages,
+                        tools: Vec::new(),
+                        thinking: None,
+                        tool_choice: None,
+                        system: None,
+                        metadata: None,
+                        stop_sequences: Vec::new(),
+                        temperature: None,
+                        top_k: None,
+                        top_p: None,
+                    };
+
+                    let custom_id = format!("req_hash_{}", hash);
+                    anthropic::batches::BatchRequest { custom_id, params }
+                })
+                .collect::<Vec<_>>();
+
+            let batch_len = batch_requests.len();
+            let batch = anthropic::batches::create_batch(
+                self.http_client.as_ref(),
+                ANTHROPIC_API_URL,
+                &self.api_key,
+                anthropic::batches::CreateBatchRequest {
+                    requests: batch_requests,
+                },
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 
-        {
-            let connection = self.connection.lock().unwrap();
-            let q = sql!(
-                UPDATE cache SET batch_id = ? WHERE batch_id is NULL
+            {
+                let connection = self.connection.lock().unwrap();
+                connection.with_savepoint("batch_upload", || {
+                    let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
+                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+                    for hash in &request_hashes {
+                        exec((batch.id.as_str(), hash.as_str()))?;
+                    }
+                    Ok(())
+                })?;
+            }
+
+            total_uploaded += batch_len;
+            log::info!(
+                "Uploaded batch {} with {} requests ({} total)",
+                batch.id,
+                batch_len,
+                total_uploaded
             );
-            connection.exec_bound(q)?(batch.id.as_str())?;
+
+            all_batch_ids.push(batch.id);
         }
 
-        log::info!("Uploaded batch with {} requests", batch_len);
+        if !all_batch_ids.is_empty() {
+            log::info!(
+                "Finished uploading {} batches with {} total requests",
+                all_batch_ids.len(),
+                total_uploaded
+            );
+        }
 
-        Ok(batch.id)
+        Ok(all_batch_ids)
     }
 
     fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {