batches.rs

  1use anyhow::Result;
  2use futures::AsyncReadExt;
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5
  6use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response};
  7
  8#[derive(Debug, Serialize, Deserialize)]
  9pub struct BatchRequest {
 10    pub custom_id: String,
 11    pub params: Request,
 12}
 13
 14#[derive(Debug, Serialize, Deserialize)]
 15pub struct CreateBatchRequest {
 16    pub requests: Vec<BatchRequest>,
 17}
 18
 19#[derive(Debug, Serialize, Deserialize)]
 20pub struct MessageBatchRequestCounts {
 21    pub processing: u64,
 22    pub succeeded: u64,
 23    pub errored: u64,
 24    pub canceled: u64,
 25    pub expired: u64,
 26}
 27
 28#[derive(Debug, Serialize, Deserialize)]
 29pub struct MessageBatch {
 30    pub id: String,
 31    #[serde(rename = "type")]
 32    pub batch_type: String,
 33    pub processing_status: String,
 34    pub request_counts: MessageBatchRequestCounts,
 35    pub ended_at: Option<String>,
 36    pub created_at: String,
 37    pub expires_at: String,
 38    pub archived_at: Option<String>,
 39    pub cancel_initiated_at: Option<String>,
 40    pub results_url: Option<String>,
 41}
 42
 43#[derive(Debug, Serialize, Deserialize)]
 44#[serde(tag = "type")]
 45pub enum BatchResult {
 46    #[serde(rename = "succeeded")]
 47    Succeeded { message: Response },
 48    #[serde(rename = "errored")]
 49    Errored { error: ApiError },
 50    #[serde(rename = "canceled")]
 51    Canceled,
 52    #[serde(rename = "expired")]
 53    Expired,
 54}
 55
 56#[derive(Debug, Serialize, Deserialize)]
 57pub struct BatchIndividualResponse {
 58    pub custom_id: String,
 59    pub result: BatchResult,
 60}
 61
 62pub async fn create_batch(
 63    client: &dyn HttpClient,
 64    api_url: &str,
 65    api_key: &str,
 66    request: CreateBatchRequest,
 67) -> Result<MessageBatch, AnthropicError> {
 68    let uri = format!("{api_url}/v1/messages/batches");
 69
 70    let request_builder = HttpRequest::builder()
 71        .method(Method::POST)
 72        .uri(uri)
 73        .header("Anthropic-Version", "2023-06-01")
 74        .header("X-Api-Key", api_key.trim())
 75        .header("Content-Type", "application/json");
 76
 77    let serialized_request =
 78        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
 79    let http_request = request_builder
 80        .body(AsyncBody::from(serialized_request))
 81        .map_err(AnthropicError::BuildRequestBody)?;
 82
 83    let mut response = client
 84        .send(http_request)
 85        .await
 86        .map_err(AnthropicError::HttpSend)?;
 87
 88    let rate_limits = RateLimitInfo::from_headers(response.headers());
 89
 90    if response.status().is_success() {
 91        let mut body = String::new();
 92        response
 93            .body_mut()
 94            .read_to_string(&mut body)
 95            .await
 96            .map_err(AnthropicError::ReadResponse)?;
 97
 98        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
 99    } else {
100        Err(crate::handle_error_response(response, rate_limits).await)
101    }
102}
103
104pub async fn retrieve_batch(
105    client: &dyn HttpClient,
106    api_url: &str,
107    api_key: &str,
108    message_batch_id: &str,
109) -> Result<MessageBatch, AnthropicError> {
110    let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}");
111
112    let request_builder = HttpRequest::builder()
113        .method(Method::GET)
114        .uri(uri)
115        .header("Anthropic-Version", "2023-06-01")
116        .header("X-Api-Key", api_key.trim());
117
118    let http_request = request_builder
119        .body(AsyncBody::default())
120        .map_err(AnthropicError::BuildRequestBody)?;
121
122    let mut response = client
123        .send(http_request)
124        .await
125        .map_err(AnthropicError::HttpSend)?;
126
127    let rate_limits = RateLimitInfo::from_headers(response.headers());
128
129    if response.status().is_success() {
130        let mut body = String::new();
131        response
132            .body_mut()
133            .read_to_string(&mut body)
134            .await
135            .map_err(AnthropicError::ReadResponse)?;
136
137        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
138    } else {
139        Err(crate::handle_error_response(response, rate_limits).await)
140    }
141}
142
143pub async fn retrieve_batch_results(
144    client: &dyn HttpClient,
145    api_url: &str,
146    api_key: &str,
147    message_batch_id: &str,
148) -> Result<Vec<BatchIndividualResponse>, AnthropicError> {
149    let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results");
150
151    let request_builder = HttpRequest::builder()
152        .method(Method::GET)
153        .uri(uri)
154        .header("Anthropic-Version", "2023-06-01")
155        .header("X-Api-Key", api_key.trim());
156
157    let http_request = request_builder
158        .body(AsyncBody::default())
159        .map_err(AnthropicError::BuildRequestBody)?;
160
161    let mut response = client
162        .send(http_request)
163        .await
164        .map_err(AnthropicError::HttpSend)?;
165
166    let rate_limits = RateLimitInfo::from_headers(response.headers());
167
168    if response.status().is_success() {
169        let mut body = String::new();
170        response
171            .body_mut()
172            .read_to_string(&mut body)
173            .await
174            .map_err(AnthropicError::ReadResponse)?;
175
176        let mut results = Vec::new();
177        for line in body.lines() {
178            if line.trim().is_empty() {
179                continue;
180            }
181            let result: BatchIndividualResponse =
182                serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?;
183            results.push(result);
184        }
185
186        Ok(results)
187    } else {
188        Err(crate::handle_error_response(response, rate_limits).await)
189    }
190}