@@ -270,9 +270,9 @@ impl BatchingLlmClient {
Ok(None)
}
- /// Uploads pending requests as a new batch; downloads finished batches if any.
+ /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
async fn sync_batches(&self) -> Result<()> {
- self.upload_pending_requests().await?;
+ let _batch_ids = self.upload_pending_requests().await?;
self.download_finished_batches().await
}
@@ -381,84 +381,114 @@ impl BatchingLlmClient {
Ok(())
}
- async fn upload_pending_requests(&self) -> Result<String> {
- let rows: Vec<(String, String)> = {
- let connection = self.connection.lock().unwrap();
- let q = sql!(
- SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
- );
- connection.select(q)?()?
- };
+ async fn upload_pending_requests(&self) -> Result<Vec<String>> {
+ const BATCH_CHUNK_SIZE: i32 = 16_000;
+ let mut all_batch_ids = Vec::new();
+ let mut total_uploaded = 0;
- if rows.is_empty() {
- return Ok(String::new());
- }
+ loop {
+ let rows: Vec<(String, String)> = {
+ let connection = self.connection.lock().unwrap();
+ let q = sql!(
+ SELECT request_hash, request FROM cache
+ WHERE batch_id IS NULL AND response IS NULL
+ LIMIT ?
+ );
+ connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
+ };
- let batch_requests = rows
- .iter()
- .map(|(hash, request_str)| {
- let serializable_request: SerializableRequest =
- serde_json::from_str(&request_str).unwrap();
-
- let messages: Vec<Message> = serializable_request
- .messages
- .into_iter()
- .map(|msg| Message {
- role: match msg.role.as_str() {
- "user" => Role::User,
- "assistant" => Role::Assistant,
- _ => Role::User,
- },
- content: vec![RequestContent::Text {
- text: msg.content,
- cache_control: None,
- }],
- })
- .collect();
-
- let params = AnthropicRequest {
- model: serializable_request.model,
- max_tokens: serializable_request.max_tokens,
- messages,
- tools: Vec::new(),
- thinking: None,
- tool_choice: None,
- system: None,
- metadata: None,
- stop_sequences: Vec::new(),
- temperature: None,
- top_k: None,
- top_p: None,
- };
-
- let custom_id = format!("req_hash_{}", hash);
- anthropic::batches::BatchRequest { custom_id, params }
- })
- .collect::<Vec<_>>();
+ if rows.is_empty() {
+ break;
+ }
- let batch_len = batch_requests.len();
- let batch = anthropic::batches::create_batch(
- self.http_client.as_ref(),
- ANTHROPIC_API_URL,
- &self.api_key,
- anthropic::batches::CreateBatchRequest {
- requests: batch_requests,
- },
- )
- .await
- .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+ let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
+
+ let batch_requests = rows
+ .iter()
+ .map(|(hash, request_str)| {
+ let serializable_request: SerializableRequest =
+ serde_json::from_str(&request_str).unwrap();
+
+ let messages: Vec<Message> = serializable_request
+ .messages
+ .into_iter()
+ .map(|msg| Message {
+ role: match msg.role.as_str() {
+ "user" => Role::User,
+ "assistant" => Role::Assistant,
+ _ => Role::User,
+ },
+ content: vec![RequestContent::Text {
+ text: msg.content,
+ cache_control: None,
+ }],
+ })
+ .collect();
+
+ let params = AnthropicRequest {
+ model: serializable_request.model,
+ max_tokens: serializable_request.max_tokens,
+ messages,
+ tools: Vec::new(),
+ thinking: None,
+ tool_choice: None,
+ system: None,
+ metadata: None,
+ stop_sequences: Vec::new(),
+ temperature: None,
+ top_k: None,
+ top_p: None,
+ };
+
+ let custom_id = format!("req_hash_{}", hash);
+ anthropic::batches::BatchRequest { custom_id, params }
+ })
+ .collect::<Vec<_>>();
+
+ let batch_len = batch_requests.len();
+ let batch = anthropic::batches::create_batch(
+ self.http_client.as_ref(),
+ ANTHROPIC_API_URL,
+ &self.api_key,
+ anthropic::batches::CreateBatchRequest {
+ requests: batch_requests,
+ },
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
- {
- let connection = self.connection.lock().unwrap();
- let q = sql!(
- UPDATE cache SET batch_id = ? WHERE batch_id is NULL
+ {
+ let connection = self.connection.lock().unwrap();
+ connection.with_savepoint("batch_upload", || {
+ let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
+ let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+ for hash in &request_hashes {
+ exec((batch.id.as_str(), hash.as_str()))?;
+ }
+ Ok(())
+ })?;
+ }
+
+ total_uploaded += batch_len;
+ log::info!(
+ "Uploaded batch {} with {} requests ({} total)",
+ batch.id,
+ batch_len,
+ total_uploaded
);
- connection.exec_bound(q)?(batch.id.as_str())?;
+
+ all_batch_ids.push(batch.id);
}
- log::info!("Uploaded batch with {} requests", batch_len);
+ if !all_batch_ids.is_empty() {
+ log::info!(
+ "Finished uploading {} batches with {} total requests",
+ all_batch_ids.len(),
+ total_uploaded
+ );
+ }
- Ok(batch.id)
+ Ok(all_batch_ids)
}
fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {