batches.rs

  1use anyhow::Result;
  2use futures::AsyncReadExt;
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5
  6use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response};
  7
  8#[derive(Debug, Serialize, Deserialize)]
  9pub struct BatchRequest {
 10    pub custom_id: String,
 11    pub params: Request,
 12}
 13
 14#[derive(Debug, Serialize, Deserialize)]
 15pub struct CreateBatchRequest {
 16    pub requests: Vec<BatchRequest>,
 17}
 18
 19#[derive(Debug, Serialize, Deserialize)]
 20pub struct MessageBatchRequestCounts {
 21    pub processing: u64,
 22    pub succeeded: u64,
 23    pub errored: u64,
 24    pub canceled: u64,
 25    pub expired: u64,
 26}
 27
 28#[derive(Debug, Serialize, Deserialize)]
 29pub struct MessageBatch {
 30    pub id: String,
 31    #[serde(rename = "type")]
 32    pub batch_type: String,
 33    pub processing_status: String,
 34    pub request_counts: MessageBatchRequestCounts,
 35    pub ended_at: Option<String>,
 36    pub created_at: String,
 37    pub expires_at: String,
 38    pub archived_at: Option<String>,
 39    pub cancel_initiated_at: Option<String>,
 40    pub results_url: Option<String>,
 41}
 42
 43#[derive(Debug, Serialize, Deserialize)]
 44#[serde(tag = "type")]
 45pub enum BatchResult {
 46    #[serde(rename = "succeeded")]
 47    Succeeded { message: Response },
 48    #[serde(rename = "errored")]
 49    Errored { error: BatchErrorResponse },
 50    #[serde(rename = "canceled")]
 51    Canceled,
 52    #[serde(rename = "expired")]
 53    Expired,
 54}
 55
 56#[derive(Debug, Serialize, Deserialize)]
 57pub struct BatchErrorResponse {
 58    #[serde(rename = "type")]
 59    pub response_type: String,
 60    pub error: ApiError,
 61}
 62
 63#[derive(Debug, Serialize, Deserialize)]
 64pub struct BatchIndividualResponse {
 65    pub custom_id: String,
 66    pub result: BatchResult,
 67}
 68
 69pub async fn create_batch(
 70    client: &dyn HttpClient,
 71    api_url: &str,
 72    api_key: &str,
 73    request: CreateBatchRequest,
 74) -> Result<MessageBatch, AnthropicError> {
 75    let uri = format!("{api_url}/v1/messages/batches");
 76
 77    let request_builder = HttpRequest::builder()
 78        .method(Method::POST)
 79        .uri(uri)
 80        .header("Anthropic-Version", "2023-06-01")
 81        .header("X-Api-Key", api_key.trim())
 82        .header("Content-Type", "application/json");
 83
 84    let serialized_request =
 85        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
 86    let http_request = request_builder
 87        .body(AsyncBody::from(serialized_request))
 88        .map_err(AnthropicError::BuildRequestBody)?;
 89
 90    let mut response = client
 91        .send(http_request)
 92        .await
 93        .map_err(AnthropicError::HttpSend)?;
 94
 95    let rate_limits = RateLimitInfo::from_headers(response.headers());
 96
 97    if response.status().is_success() {
 98        let mut body = String::new();
 99        response
100            .body_mut()
101            .read_to_string(&mut body)
102            .await
103            .map_err(AnthropicError::ReadResponse)?;
104
105        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
106    } else {
107        Err(crate::handle_error_response(response, rate_limits).await)
108    }
109}
110
111pub async fn retrieve_batch(
112    client: &dyn HttpClient,
113    api_url: &str,
114    api_key: &str,
115    message_batch_id: &str,
116) -> Result<MessageBatch, AnthropicError> {
117    let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}");
118
119    let request_builder = HttpRequest::builder()
120        .method(Method::GET)
121        .uri(uri)
122        .header("Anthropic-Version", "2023-06-01")
123        .header("X-Api-Key", api_key.trim());
124
125    let http_request = request_builder
126        .body(AsyncBody::default())
127        .map_err(AnthropicError::BuildRequestBody)?;
128
129    let mut response = client
130        .send(http_request)
131        .await
132        .map_err(AnthropicError::HttpSend)?;
133
134    let rate_limits = RateLimitInfo::from_headers(response.headers());
135
136    if response.status().is_success() {
137        let mut body = String::new();
138        response
139            .body_mut()
140            .read_to_string(&mut body)
141            .await
142            .map_err(AnthropicError::ReadResponse)?;
143
144        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
145    } else {
146        Err(crate::handle_error_response(response, rate_limits).await)
147    }
148}
149
150pub async fn retrieve_batch_results(
151    client: &dyn HttpClient,
152    api_url: &str,
153    api_key: &str,
154    message_batch_id: &str,
155) -> Result<Vec<BatchIndividualResponse>, AnthropicError> {
156    let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results");
157
158    let request_builder = HttpRequest::builder()
159        .method(Method::GET)
160        .uri(uri)
161        .header("Anthropic-Version", "2023-06-01")
162        .header("X-Api-Key", api_key.trim());
163
164    let http_request = request_builder
165        .body(AsyncBody::default())
166        .map_err(AnthropicError::BuildRequestBody)?;
167
168    let mut response = client
169        .send(http_request)
170        .await
171        .map_err(AnthropicError::HttpSend)?;
172
173    let rate_limits = RateLimitInfo::from_headers(response.headers());
174
175    if response.status().is_success() {
176        let mut body = String::new();
177        response
178            .body_mut()
179            .read_to_string(&mut body)
180            .await
181            .map_err(AnthropicError::ReadResponse)?;
182
183        let mut results = Vec::new();
184        for line in body.lines() {
185            if line.trim().is_empty() {
186                continue;
187            }
188            let result: BatchIndividualResponse =
189                serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?;
190            results.push(result);
191        }
192
193        Ok(results)
194    } else {
195        Err(crate::handle_error_response(response, rate_limits).await)
196    }
197}