1use anyhow::Result;
2use futures::AsyncReadExt;
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5
6use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response};
7
8#[derive(Debug, Serialize, Deserialize)]
9pub struct BatchRequest {
10 pub custom_id: String,
11 pub params: Request,
12}
13
14#[derive(Debug, Serialize, Deserialize)]
15pub struct CreateBatchRequest {
16 pub requests: Vec<BatchRequest>,
17}
18
19#[derive(Debug, Serialize, Deserialize)]
20pub struct MessageBatchRequestCounts {
21 pub processing: u64,
22 pub succeeded: u64,
23 pub errored: u64,
24 pub canceled: u64,
25 pub expired: u64,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29pub struct MessageBatch {
30 pub id: String,
31 #[serde(rename = "type")]
32 pub batch_type: String,
33 pub processing_status: String,
34 pub request_counts: MessageBatchRequestCounts,
35 pub ended_at: Option<String>,
36 pub created_at: String,
37 pub expires_at: String,
38 pub archived_at: Option<String>,
39 pub cancel_initiated_at: Option<String>,
40 pub results_url: Option<String>,
41}
42
43#[derive(Debug, Serialize, Deserialize)]
44#[serde(tag = "type")]
45pub enum BatchResult {
46 #[serde(rename = "succeeded")]
47 Succeeded { message: Response },
48 #[serde(rename = "errored")]
49 Errored { error: BatchErrorResponse },
50 #[serde(rename = "canceled")]
51 Canceled,
52 #[serde(rename = "expired")]
53 Expired,
54}
55
56#[derive(Debug, Serialize, Deserialize)]
57pub struct BatchErrorResponse {
58 #[serde(rename = "type")]
59 pub response_type: String,
60 pub error: ApiError,
61}
62
63#[derive(Debug, Serialize, Deserialize)]
64pub struct BatchIndividualResponse {
65 pub custom_id: String,
66 pub result: BatchResult,
67}
68
69pub async fn create_batch(
70 client: &dyn HttpClient,
71 api_url: &str,
72 api_key: &str,
73 request: CreateBatchRequest,
74) -> Result<MessageBatch, AnthropicError> {
75 let uri = format!("{api_url}/v1/messages/batches");
76
77 let request_builder = HttpRequest::builder()
78 .method(Method::POST)
79 .uri(uri)
80 .header("Anthropic-Version", "2023-06-01")
81 .header("X-Api-Key", api_key.trim())
82 .header("Content-Type", "application/json");
83
84 let serialized_request =
85 serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
86 let http_request = request_builder
87 .body(AsyncBody::from(serialized_request))
88 .map_err(AnthropicError::BuildRequestBody)?;
89
90 let mut response = client
91 .send(http_request)
92 .await
93 .map_err(AnthropicError::HttpSend)?;
94
95 let rate_limits = RateLimitInfo::from_headers(response.headers());
96
97 if response.status().is_success() {
98 let mut body = String::new();
99 response
100 .body_mut()
101 .read_to_string(&mut body)
102 .await
103 .map_err(AnthropicError::ReadResponse)?;
104
105 serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
106 } else {
107 Err(crate::handle_error_response(response, rate_limits).await)
108 }
109}
110
111pub async fn retrieve_batch(
112 client: &dyn HttpClient,
113 api_url: &str,
114 api_key: &str,
115 message_batch_id: &str,
116) -> Result<MessageBatch, AnthropicError> {
117 let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}");
118
119 let request_builder = HttpRequest::builder()
120 .method(Method::GET)
121 .uri(uri)
122 .header("Anthropic-Version", "2023-06-01")
123 .header("X-Api-Key", api_key.trim());
124
125 let http_request = request_builder
126 .body(AsyncBody::default())
127 .map_err(AnthropicError::BuildRequestBody)?;
128
129 let mut response = client
130 .send(http_request)
131 .await
132 .map_err(AnthropicError::HttpSend)?;
133
134 let rate_limits = RateLimitInfo::from_headers(response.headers());
135
136 if response.status().is_success() {
137 let mut body = String::new();
138 response
139 .body_mut()
140 .read_to_string(&mut body)
141 .await
142 .map_err(AnthropicError::ReadResponse)?;
143
144 serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
145 } else {
146 Err(crate::handle_error_response(response, rate_limits).await)
147 }
148}
149
150pub async fn retrieve_batch_results(
151 client: &dyn HttpClient,
152 api_url: &str,
153 api_key: &str,
154 message_batch_id: &str,
155) -> Result<Vec<BatchIndividualResponse>, AnthropicError> {
156 let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results");
157
158 let request_builder = HttpRequest::builder()
159 .method(Method::GET)
160 .uri(uri)
161 .header("Anthropic-Version", "2023-06-01")
162 .header("X-Api-Key", api_key.trim());
163
164 let http_request = request_builder
165 .body(AsyncBody::default())
166 .map_err(AnthropicError::BuildRequestBody)?;
167
168 let mut response = client
169 .send(http_request)
170 .await
171 .map_err(AnthropicError::HttpSend)?;
172
173 let rate_limits = RateLimitInfo::from_headers(response.headers());
174
175 if response.status().is_success() {
176 let mut body = String::new();
177 response
178 .body_mut()
179 .read_to_string(&mut body)
180 .await
181 .map_err(AnthropicError::ReadResponse)?;
182
183 let mut results = Vec::new();
184 for line in body.lines() {
185 if line.trim().is_empty() {
186 continue;
187 }
188 let result: BatchIndividualResponse =
189 serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?;
190 results.push(result);
191 }
192
193 Ok(results)
194 } else {
195 Err(crate::handle_error_response(response, rate_limits).await)
196 }
197}