diff --git a/crates/edit_prediction_cli/src/anthropic_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs
index e7723926860d80e70cb3f3a40a47d1a9a0a490c7..0f7c9127e7388c37395b469cc3339bb4a7f53783 100644
--- a/crates/edit_prediction_cli/src/anthropic_client.rs
+++ b/crates/edit_prediction_cli/src/anthropic_client.rs
@@ -270,9 +270,9 @@ impl BatchingLlmClient {
         Ok(None)
     }
 
-    /// Uploads pending requests as a new batch; downloads finished batches if any.
+    /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
     async fn sync_batches(&self) -> Result<()> {
-        self.upload_pending_requests().await?;
+        let _batch_ids = self.upload_pending_requests().await?;
         self.download_finished_batches().await
     }
 
@@ -381,84 +381,114 @@ impl BatchingLlmClient {
         Ok(())
     }
 
-    async fn upload_pending_requests(&self) -> Result<String> {
-        let rows: Vec<(String, String)> = {
-            let connection = self.connection.lock().unwrap();
-            let q = sql!(
-                SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
-            );
-            connection.select(q)?()?
-        };
+    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
+        const BATCH_CHUNK_SIZE: i32 = 16_000;
+        let mut all_batch_ids = Vec::new();
+        let mut total_uploaded = 0;
 
-        if rows.is_empty() {
-            return Ok(String::new());
-        }
+        loop {
+            let rows: Vec<(String, String)> = {
+                let connection = self.connection.lock().unwrap();
+                let q = sql!(
+                    SELECT request_hash, request FROM cache
+                    WHERE batch_id IS NULL AND response IS NULL
+                    LIMIT ?
+                );
+                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
+            };
 
-        let batch_requests = rows
-            .iter()
-            .map(|(hash, request_str)| {
-                let serializable_request: SerializableRequest =
-                    serde_json::from_str(&request_str).unwrap();
-
-                let messages: Vec<Message> = serializable_request
-                    .messages
-                    .into_iter()
-                    .map(|msg| Message {
-                        role: match msg.role.as_str() {
-                            "user" => Role::User,
-                            "assistant" => Role::Assistant,
-                            _ => Role::User,
-                        },
-                        content: vec![RequestContent::Text {
-                            text: msg.content,
-                            cache_control: None,
-                        }],
-                    })
-                    .collect();
-
-                let params = AnthropicRequest {
-                    model: serializable_request.model,
-                    max_tokens: serializable_request.max_tokens,
-                    messages,
-                    tools: Vec::new(),
-                    thinking: None,
-                    tool_choice: None,
-                    system: None,
-                    metadata: None,
-                    stop_sequences: Vec::new(),
-                    temperature: None,
-                    top_k: None,
-                    top_p: None,
-                };
-
-                let custom_id = format!("req_hash_{}", hash);
-                anthropic::batches::BatchRequest { custom_id, params }
-            })
-            .collect::<Vec<_>>();
+            if rows.is_empty() {
+                break;
+            }
 
-        let batch_len = batch_requests.len();
-        let batch = anthropic::batches::create_batch(
-            self.http_client.as_ref(),
-            ANTHROPIC_API_URL,
-            &self.api_key,
-            anthropic::batches::CreateBatchRequest {
-                requests: batch_requests,
-            },
-        )
-        .await
-        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
+
+            let batch_requests = rows
+                .iter()
+                .map(|(hash, request_str)| {
+                    let serializable_request: SerializableRequest =
+                        serde_json::from_str(&request_str).unwrap();
+
+                    let messages: Vec<Message> = serializable_request
+                        .messages
+                        .into_iter()
+                        .map(|msg| Message {
+                            role: match msg.role.as_str() {
+                                "user" => Role::User,
+                                "assistant" => Role::Assistant,
+                                _ => Role::User,
+                            },
+                            content: vec![RequestContent::Text {
+                                text: msg.content,
+                                cache_control: None,
+                            }],
+                        })
+                        .collect();
+
+                    let params = AnthropicRequest {
+                        model: serializable_request.model,
+                        max_tokens: serializable_request.max_tokens,
+                        messages,
+                        tools: Vec::new(),
+                        thinking: None,
+                        tool_choice: None,
+                        system: None,
+                        metadata: None,
+                        stop_sequences: Vec::new(),
+                        temperature: None,
+                        top_k: None,
+                        top_p: None,
+                    };
+
+                    let custom_id = format!("req_hash_{}", hash);
+                    anthropic::batches::BatchRequest { custom_id, params }
+                })
+                .collect::<Vec<_>>();
+
+            let batch_len = batch_requests.len();
+            let batch = anthropic::batches::create_batch(
+                self.http_client.as_ref(),
+                ANTHROPIC_API_URL,
+                &self.api_key,
+                anthropic::batches::CreateBatchRequest {
+                    requests: batch_requests,
+                },
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 
-        {
-            let connection = self.connection.lock().unwrap();
-            let q = sql!(
-                UPDATE cache SET batch_id = ? WHERE batch_id is NULL
+            {
+                let connection = self.connection.lock().unwrap();
+                connection.with_savepoint("batch_upload", || {
+                    let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
+                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+                    for hash in &request_hashes {
+                        exec((batch.id.as_str(), hash.as_str()))?;
+                    }
+                    Ok(())
+                })?;
+            }
+
+            total_uploaded += batch_len;
+            log::info!(
+                "Uploaded batch {} with {} requests ({} total)",
+                batch.id,
+                batch_len,
+                total_uploaded
             );
-            connection.exec_bound(q)?(batch.id.as_str())?;
+
+            all_batch_ids.push(batch.id);
         }
 
-        log::info!("Uploaded batch with {} requests", batch_len);
+        if !all_batch_ids.is_empty() {
+            log::info!(
+                "Finished uploading {} batches with {} total requests",
+                all_batch_ids.len(),
+                total_uploaded
+            );
+        }
 
-        Ok(batch.id)
+        Ok(all_batch_ids)
     }
 
     fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {