1use anyhow::Result;
2use futures::AsyncReadExt;
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5
6use crate::{Request, RequestError, Response};
7
/// A single request within a batch.
///
/// Each item is serialized to one line of the JSONL input file via
/// [`BatchRequestItem::to_jsonl_line`].
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchRequestItem {
    /// Caller-chosen identifier for this request.
    /// NOTE(review): presumably echoed back as `BatchOutputItem::custom_id` so results
    /// can be matched to inputs — confirm it must be unique within the batch.
    pub custom_id: String,
    /// HTTP method of the batched call (`new` sets `"POST"`).
    pub method: String,
    /// API path of the batched call (`new` sets `"/v1/chat/completions"`).
    pub url: String,
    /// The request payload sent to `url`.
    pub body: Request,
}
16
17impl BatchRequestItem {
18 pub fn new(custom_id: String, request: Request) -> Self {
19 Self {
20 custom_id,
21 method: "POST".to_string(),
22 url: "/v1/chat/completions".to_string(),
23 body: request,
24 }
25 }
26
27 pub fn to_jsonl_line(&self) -> Result<String, serde_json::Error> {
28 serde_json::to_string(self)
29 }
30}
31
/// Request body for creating a batch from an uploaded input file.
///
/// Serialized as JSON by `create_batch`.
#[derive(Debug, Serialize)]
pub struct CreateBatchRequest {
    /// ID of the uploaded JSONL file (see `upload_batch_file` / `FileObject::id`).
    pub input_file_id: String,
    /// Endpoint every request in the batch targets (`new` sets `"/v1/chat/completions"`).
    pub endpoint: String,
    /// Completion window for the batch (`new` sets `"24h"`).
    pub completion_window: String,
    /// Optional free-form metadata; the field is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
41
42impl CreateBatchRequest {
43 pub fn new(input_file_id: String) -> Self {
44 Self {
45 input_file_id,
46 endpoint: "/v1/chat/completions".to_string(),
47 completion_window: "24h".to_string(),
48 metadata: None,
49 }
50 }
51}
52
/// A batch job as returned by batch creation or retrieval.
///
/// Produced by `create_batch` and `retrieve_batch`. Fields marked
/// `#[serde(default)]` deserialize to `None` when absent from the JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Batch {
    /// Server-assigned batch ID (the value `retrieve_batch` takes as `batch_id`).
    pub id: String,
    /// Object type tag reported by the API.
    pub object: String,
    /// Endpoint the batched requests target.
    pub endpoint: String,
    /// ID of the JSONL input file the batch was created from.
    pub input_file_id: String,
    /// Completion window the batch was created with.
    pub completion_window: String,
    /// Batch status exactly as reported by the API (kept as a raw string).
    pub status: String,
    /// ID of the results file, if present; its content can be fetched with `download_file`.
    pub output_file_id: Option<String>,
    /// ID of the error file, if present.
    pub error_file_id: Option<String>,
    /// Creation time.
    /// NOTE(review): presumably Unix seconds — confirm against the API docs.
    pub created_at: u64,
    // NOTE(review): the `*_at` fields below look like per-state transition
    // timestamps (presumably Unix seconds), `None` when not reported — confirm.
    #[serde(default)]
    pub in_progress_at: Option<u64>,
    #[serde(default)]
    pub expires_at: Option<u64>,
    #[serde(default)]
    pub finalizing_at: Option<u64>,
    #[serde(default)]
    pub completed_at: Option<u64>,
    #[serde(default)]
    pub failed_at: Option<u64>,
    #[serde(default)]
    pub expired_at: Option<u64>,
    #[serde(default)]
    pub cancelling_at: Option<u64>,
    #[serde(default)]
    pub cancelled_at: Option<u64>,
    /// Per-batch request tallies, when the API includes them.
    #[serde(default)]
    pub request_counts: Option<BatchRequestCounts>,
    /// Free-form metadata attached to the batch, if any.
    #[serde(default)]
    pub metadata: Option<serde_json::Value>,
}
86
/// Request tallies reported for a batch (see `Batch::request_counts`).
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct BatchRequestCounts {
    /// Total number of requests in the batch.
    pub total: u64,
    /// Number of requests that completed.
    pub completed: u64,
    /// Number of requests that failed.
    pub failed: u64,
}
93
/// File metadata returned by `upload_batch_file`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileObject {
    /// Server-assigned file ID; pass it as `CreateBatchRequest::input_file_id`.
    pub id: String,
    /// Object type tag reported by the API.
    pub object: String,
    /// File size in bytes.
    pub bytes: u64,
    /// Upload time.
    /// NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
    /// Filename as recorded by the server.
    pub filename: String,
    /// Upload purpose (`upload_batch_file` sends `"batch"`).
    pub purpose: String,
}
104
/// One line of a batch output file, as parsed by `parse_batch_output`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchOutputItem {
    /// Server-assigned ID of this result line.
    pub id: String,
    /// The `custom_id` of the `BatchRequestItem` this result belongs to.
    pub custom_id: String,
    /// HTTP-level response for the request, if one was produced.
    pub response: Option<BatchResponseBody>,
    /// Request-level error, if any.
    /// NOTE(review): presumably exactly one of `response` / `error` is set — confirm.
    pub error: Option<BatchError>,
}
113
/// The HTTP response recorded for one batched request.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchResponseBody {
    /// HTTP status code of the batched call.
    pub status_code: u16,
    /// Decoded response payload.
    /// NOTE(review): for a non-2xx `status_code` the API likely puts an error object
    /// here. If `Response` cannot deserialize that shape, `parse_batch_output` fails
    /// for the whole file. Consider a more permissive type — confirm.
    pub body: Response,
}
119
/// A request-level error reported in a batch output line.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchError {
    /// Machine-readable error code, as sent by the API.
    pub code: String,
    /// Human-readable error description.
    pub message: String,
}
125
126/// Upload a JSONL file for batch processing
127pub async fn upload_batch_file(
128 client: &dyn HttpClient,
129 api_url: &str,
130 api_key: &str,
131 filename: &str,
132 content: Vec<u8>,
133) -> Result<FileObject, RequestError> {
134 let uri = format!("{api_url}/files");
135
136 let boundary = format!("----WebKitFormBoundary{:x}", rand::random::<u64>());
137
138 let mut body = Vec::new();
139 body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
140 body.extend_from_slice(b"Content-Disposition: form-data; name=\"purpose\"\r\n\r\n");
141 body.extend_from_slice(b"batch\r\n");
142 body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
143 body.extend_from_slice(
144 format!("Content-Disposition: form-data; name=\"file\"; filename=\"{filename}\"\r\n")
145 .as_bytes(),
146 );
147 body.extend_from_slice(b"Content-Type: application/jsonl\r\n\r\n");
148 body.extend_from_slice(&content);
149 body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes());
150
151 let request = HttpRequest::builder()
152 .method(Method::POST)
153 .uri(uri)
154 .header("Authorization", format!("Bearer {}", api_key.trim()))
155 .header(
156 "Content-Type",
157 format!("multipart/form-data; boundary={boundary}"),
158 )
159 .body(AsyncBody::from(body))
160 .map_err(|e| RequestError::Other(e.into()))?;
161
162 let mut response = client.send(request).await?;
163
164 if response.status().is_success() {
165 let mut body = String::new();
166 response
167 .body_mut()
168 .read_to_string(&mut body)
169 .await
170 .map_err(|e| RequestError::Other(e.into()))?;
171
172 serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
173 } else {
174 let mut body = String::new();
175 response
176 .body_mut()
177 .read_to_string(&mut body)
178 .await
179 .map_err(|e| RequestError::Other(e.into()))?;
180
181 Err(RequestError::HttpResponseError {
182 provider: "openai".to_owned(),
183 status_code: response.status(),
184 body,
185 headers: response.headers().clone(),
186 })
187 }
188}
189
190/// Create a batch from an uploaded file
191pub async fn create_batch(
192 client: &dyn HttpClient,
193 api_url: &str,
194 api_key: &str,
195 request: CreateBatchRequest,
196) -> Result<Batch, RequestError> {
197 let uri = format!("{api_url}/batches");
198
199 let serialized = serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?;
200
201 let request = HttpRequest::builder()
202 .method(Method::POST)
203 .uri(uri)
204 .header("Authorization", format!("Bearer {}", api_key.trim()))
205 .header("Content-Type", "application/json")
206 .body(AsyncBody::from(serialized))
207 .map_err(|e| RequestError::Other(e.into()))?;
208
209 let mut response = client.send(request).await?;
210
211 if response.status().is_success() {
212 let mut body = String::new();
213 response
214 .body_mut()
215 .read_to_string(&mut body)
216 .await
217 .map_err(|e| RequestError::Other(e.into()))?;
218
219 serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
220 } else {
221 let mut body = String::new();
222 response
223 .body_mut()
224 .read_to_string(&mut body)
225 .await
226 .map_err(|e| RequestError::Other(e.into()))?;
227
228 Err(RequestError::HttpResponseError {
229 provider: "openai".to_owned(),
230 status_code: response.status(),
231 body,
232 headers: response.headers().clone(),
233 })
234 }
235}
236
237/// Retrieve batch status
238pub async fn retrieve_batch(
239 client: &dyn HttpClient,
240 api_url: &str,
241 api_key: &str,
242 batch_id: &str,
243) -> Result<Batch, RequestError> {
244 let uri = format!("{api_url}/batches/{batch_id}");
245
246 let request = HttpRequest::builder()
247 .method(Method::GET)
248 .uri(uri)
249 .header("Authorization", format!("Bearer {}", api_key.trim()))
250 .body(AsyncBody::default())
251 .map_err(|e| RequestError::Other(e.into()))?;
252
253 let mut response = client.send(request).await?;
254
255 if response.status().is_success() {
256 let mut body = String::new();
257 response
258 .body_mut()
259 .read_to_string(&mut body)
260 .await
261 .map_err(|e| RequestError::Other(e.into()))?;
262
263 serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
264 } else {
265 let mut body = String::new();
266 response
267 .body_mut()
268 .read_to_string(&mut body)
269 .await
270 .map_err(|e| RequestError::Other(e.into()))?;
271
272 Err(RequestError::HttpResponseError {
273 provider: "openai".to_owned(),
274 status_code: response.status(),
275 body,
276 headers: response.headers().clone(),
277 })
278 }
279}
280
281/// Download file content (for batch results)
282pub async fn download_file(
283 client: &dyn HttpClient,
284 api_url: &str,
285 api_key: &str,
286 file_id: &str,
287) -> Result<String, RequestError> {
288 let uri = format!("{api_url}/files/{file_id}/content");
289
290 let request = HttpRequest::builder()
291 .method(Method::GET)
292 .uri(uri)
293 .header("Authorization", format!("Bearer {}", api_key.trim()))
294 .body(AsyncBody::default())
295 .map_err(|e| RequestError::Other(e.into()))?;
296
297 let mut response = client.send(request).await?;
298
299 if response.status().is_success() {
300 let mut body = String::new();
301 response
302 .body_mut()
303 .read_to_string(&mut body)
304 .await
305 .map_err(|e| RequestError::Other(e.into()))?;
306
307 Ok(body)
308 } else {
309 let mut body = String::new();
310 response
311 .body_mut()
312 .read_to_string(&mut body)
313 .await
314 .map_err(|e| RequestError::Other(e.into()))?;
315
316 Err(RequestError::HttpResponseError {
317 provider: "openai".to_owned(),
318 status_code: response.status(),
319 body,
320 headers: response.headers().clone(),
321 })
322 }
323}
324
325/// Parse batch output JSONL into individual results
326pub fn parse_batch_output(content: &str) -> Result<Vec<BatchOutputItem>, serde_json::Error> {
327 content
328 .lines()
329 .filter(|line| !line.trim().is_empty())
330 .map(|line| serde_json::from_str(line))
331 .collect()
332}