Add `ep import-batch` to download finished Anthropic batches (#47364)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/anthropic_client.rs | 127 ++++++++++++++++
crates/edit_prediction_cli/src/main.rs             |  30 +++
2 files changed, 156 insertions(+), 1 deletion(-)

Detailed changes

crates/edit_prediction_cli/src/anthropic_client.rs 🔗

@@ -276,6 +276,121 @@ impl BatchingLlmClient {
         self.download_finished_batches().await
     }
 
+    /// Import batch results from external batch IDs (useful for recovering after database loss)
+    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
+        for batch_id in batch_ids {
+            log::info!("Importing batch {}", batch_id);
+
+            let batch_status = anthropic::batches::retrieve_batch(
+                self.http_client.as_ref(),
+                ANTHROPIC_API_URL,
+                &self.api_key,
+                batch_id,
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;
+
+            log::info!(
+                "Batch {} status: {}",
+                batch_id,
+                batch_status.processing_status
+            );
+
+            if batch_status.processing_status != "ended" {
+                log::warn!(
+                    "Batch {} is not finished (status: {}), skipping",
+                    batch_id,
+                    batch_status.processing_status
+                );
+                continue;
+            }
+
+            let results = anthropic::batches::retrieve_batch_results(
+                self.http_client.as_ref(),
+                ANTHROPIC_API_URL,
+                &self.api_key,
+                batch_id,
+            )
+            .await
+            .map_err(|e| {
+                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
+            })?;
+
+            let mut updates: Vec<(String, String, String)> = Vec::new();
+            let mut success_count = 0;
+            let mut error_count = 0;
+
+            for result in results {
+                let request_hash = result
+                    .custom_id
+                    .strip_prefix("req_hash_")
+                    .unwrap_or(&result.custom_id)
+                    .to_string();
+
+                match result.result {
+                    anthropic::batches::BatchResult::Succeeded { message } => {
+                        let response_json = serde_json::to_string(&message)?;
+                        updates.push((request_hash, response_json, batch_id.clone()));
+                        success_count += 1;
+                    }
+                    anthropic::batches::BatchResult::Errored { error } => {
+                        log::error!(
+                            "Batch request {} failed: {}: {}",
+                            request_hash,
+                            error.error.error_type,
+                            error.error.message
+                        );
+                        let error_json = serde_json::json!({
+                            "error": {
+                                "type": error.error.error_type,
+                                "message": error.error.message
+                            }
+                        })
+                        .to_string();
+                        updates.push((request_hash, error_json, batch_id.clone()));
+                        error_count += 1;
+                    }
+                    anthropic::batches::BatchResult::Canceled => {
+                        log::warn!("Batch request {} was canceled", request_hash);
+                        error_count += 1;
+                    }
+                    anthropic::batches::BatchResult::Expired => {
+                        log::warn!("Batch request {} expired", request_hash);
+                        error_count += 1;
+                    }
+                }
+            }
+
+            let connection = self.connection.lock().unwrap();
+            connection.with_savepoint("batch_import", || {
+                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
+                let q = sql!(
+                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
+                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
+                );
+                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
+                for (request_hash, response_json, batch_id) in &updates {
+                    exec((
+                        request_hash.as_str(),
+                        request_hash.as_str(),
+                        response_json.as_str(),
+                        batch_id.as_str(),
+                    ))?;
+                }
+                Ok(())
+            })?;
+
+            log::info!(
+                "Imported batch {}: {} successful, {} errors",
+                batch_id,
+                success_count,
+                error_count
+            );
+        }
+
+        Ok(())
+    }
+
     async fn download_finished_batches(&self) -> Result<()> {
         let batch_ids: Vec<String> = {
             let connection = self.connection.lock().unwrap();
@@ -585,4 +700,16 @@ impl AnthropicClient {
             AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
         }
     }
+
+    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
+        match self {
+            AnthropicClient::Plain(_) => {
+                anyhow::bail!("Import batches is only supported with batching client")
+            }
+            AnthropicClient::Batch(batching_llm_client) => {
+                batching_llm_client.import_batches(batch_ids).await
+            }
+            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+        }
+    }
 }

crates/edit_prediction_cli/src/main.rs 🔗

@@ -157,6 +157,8 @@ enum Command {
     Split(SplitArgs),
     /// Filter a JSONL dataset by programming language (based on cursor_path extension)
     FilterLanguages(FilterLanguagesArgs),
+    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
+    ImportBatch(ImportBatchArgs),
 }
 
 impl Display for Command {
@@ -189,6 +191,9 @@ impl Display for Command {
             Command::SplitCommit(_) => write!(f, "split-commit"),
             Command::Split(_) => write!(f, "split"),
             Command::FilterLanguages(_) => write!(f, "filter-languages"),
+            Command::ImportBatch(args) => {
+                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
+            }
         }
     }
 }
@@ -308,6 +313,13 @@ struct SynthesizeArgs {
     fresh: bool,
 }
 
+#[derive(Debug, Args, Clone)]
+struct ImportBatchArgs {
+    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
+    #[clap(long, required = true, num_args = 1..)]
+    batch_ids: Vec<String>,
+}
+
 impl EpArgs {
     fn output_path(&self) -> Option<PathBuf> {
         if self.in_place {
@@ -469,6 +481,21 @@ fn main() {
     };
 
     match &command {
+        Command::ImportBatch(import_args) => {
+            smol::block_on(async {
+                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
+                    .expect("Failed to create Anthropic client");
+                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
+                    eprintln!("Error importing batches: {:?}", e);
+                    std::process::exit(1);
+                }
+                println!(
+                    "Successfully imported {} batch(es)",
+                    import_args.batch_ids.len()
+                );
+            });
+            return;
+        }
         Command::Clean => {
             std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
             return;
@@ -672,7 +699,8 @@ fn main() {
                                         | Command::Synthesize(_)
                                         | Command::SplitCommit(_)
                                         | Command::Split(_)
-                                        | Command::FilterLanguages(_) => {
+                                        | Command::FilterLanguages(_)
+                                        | Command::ImportBatch(_) => {
                                             unreachable!()
                                         }
                                     }