batches.rs

  1use anyhow::Result;
  2use futures::AsyncReadExt;
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5
  6use crate::{Request, RequestError, Response};
  7
  8/// A single request within a batch
  9#[derive(Debug, Serialize, Deserialize)]
 10pub struct BatchRequestItem {
 11    pub custom_id: String,
 12    pub method: String,
 13    pub url: String,
 14    pub body: Request,
 15}
 16
 17impl BatchRequestItem {
 18    pub fn new(custom_id: String, request: Request) -> Self {
 19        Self {
 20            custom_id,
 21            method: "POST".to_string(),
 22            url: "/v1/chat/completions".to_string(),
 23            body: request,
 24        }
 25    }
 26
 27    pub fn to_jsonl_line(&self) -> Result<String, serde_json::Error> {
 28        serde_json::to_string(self)
 29    }
 30}
 31
 32/// Request to create a batch
 33#[derive(Debug, Serialize)]
 34pub struct CreateBatchRequest {
 35    pub input_file_id: String,
 36    pub endpoint: String,
 37    pub completion_window: String,
 38    #[serde(skip_serializing_if = "Option::is_none")]
 39    pub metadata: Option<serde_json::Value>,
 40}
 41
 42impl CreateBatchRequest {
 43    pub fn new(input_file_id: String) -> Self {
 44        Self {
 45            input_file_id,
 46            endpoint: "/v1/chat/completions".to_string(),
 47            completion_window: "24h".to_string(),
 48            metadata: None,
 49        }
 50    }
 51}
 52
 53/// Response from batch creation or retrieval
 54#[derive(Debug, Serialize, Deserialize)]
 55pub struct Batch {
 56    pub id: String,
 57    pub object: String,
 58    pub endpoint: String,
 59    pub input_file_id: String,
 60    pub completion_window: String,
 61    pub status: String,
 62    pub output_file_id: Option<String>,
 63    pub error_file_id: Option<String>,
 64    pub created_at: u64,
 65    #[serde(default)]
 66    pub in_progress_at: Option<u64>,
 67    #[serde(default)]
 68    pub expires_at: Option<u64>,
 69    #[serde(default)]
 70    pub finalizing_at: Option<u64>,
 71    #[serde(default)]
 72    pub completed_at: Option<u64>,
 73    #[serde(default)]
 74    pub failed_at: Option<u64>,
 75    #[serde(default)]
 76    pub expired_at: Option<u64>,
 77    #[serde(default)]
 78    pub cancelling_at: Option<u64>,
 79    #[serde(default)]
 80    pub cancelled_at: Option<u64>,
 81    #[serde(default)]
 82    pub request_counts: Option<BatchRequestCounts>,
 83    #[serde(default)]
 84    pub metadata: Option<serde_json::Value>,
 85}
 86
 87#[derive(Debug, Serialize, Deserialize, Default)]
 88pub struct BatchRequestCounts {
 89    pub total: u64,
 90    pub completed: u64,
 91    pub failed: u64,
 92}
 93
 94/// Response from file upload
 95#[derive(Debug, Serialize, Deserialize)]
 96pub struct FileObject {
 97    pub id: String,
 98    pub object: String,
 99    pub bytes: u64,
100    pub created_at: u64,
101    pub filename: String,
102    pub purpose: String,
103}
104
105/// Individual result from batch output
106#[derive(Debug, Serialize, Deserialize)]
107pub struct BatchOutputItem {
108    pub id: String,
109    pub custom_id: String,
110    pub response: Option<BatchResponseBody>,
111    pub error: Option<BatchError>,
112}
113
114#[derive(Debug, Serialize, Deserialize)]
115pub struct BatchResponseBody {
116    pub status_code: u16,
117    pub body: Response,
118}
119
120#[derive(Debug, Serialize, Deserialize)]
121pub struct BatchError {
122    pub code: String,
123    pub message: String,
124}
125
126/// Upload a JSONL file for batch processing
127pub async fn upload_batch_file(
128    client: &dyn HttpClient,
129    api_url: &str,
130    api_key: &str,
131    filename: &str,
132    content: Vec<u8>,
133) -> Result<FileObject, RequestError> {
134    let uri = format!("{api_url}/files");
135
136    let boundary = format!("----WebKitFormBoundary{:x}", rand::random::<u64>());
137
138    let mut body = Vec::new();
139    body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
140    body.extend_from_slice(b"Content-Disposition: form-data; name=\"purpose\"\r\n\r\n");
141    body.extend_from_slice(b"batch\r\n");
142    body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
143    body.extend_from_slice(
144        format!("Content-Disposition: form-data; name=\"file\"; filename=\"{filename}\"\r\n")
145            .as_bytes(),
146    );
147    body.extend_from_slice(b"Content-Type: application/jsonl\r\n\r\n");
148    body.extend_from_slice(&content);
149    body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes());
150
151    let request = HttpRequest::builder()
152        .method(Method::POST)
153        .uri(uri)
154        .header("Authorization", format!("Bearer {}", api_key.trim()))
155        .header(
156            "Content-Type",
157            format!("multipart/form-data; boundary={boundary}"),
158        )
159        .body(AsyncBody::from(body))
160        .map_err(|e| RequestError::Other(e.into()))?;
161
162    let mut response = client.send(request).await?;
163
164    if response.status().is_success() {
165        let mut body = String::new();
166        response
167            .body_mut()
168            .read_to_string(&mut body)
169            .await
170            .map_err(|e| RequestError::Other(e.into()))?;
171
172        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
173    } else {
174        let mut body = String::new();
175        response
176            .body_mut()
177            .read_to_string(&mut body)
178            .await
179            .map_err(|e| RequestError::Other(e.into()))?;
180
181        Err(RequestError::HttpResponseError {
182            provider: "openai".to_owned(),
183            status_code: response.status(),
184            body,
185            headers: response.headers().clone(),
186        })
187    }
188}
189
190/// Create a batch from an uploaded file
191pub async fn create_batch(
192    client: &dyn HttpClient,
193    api_url: &str,
194    api_key: &str,
195    request: CreateBatchRequest,
196) -> Result<Batch, RequestError> {
197    let uri = format!("{api_url}/batches");
198
199    let serialized = serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?;
200
201    let request = HttpRequest::builder()
202        .method(Method::POST)
203        .uri(uri)
204        .header("Authorization", format!("Bearer {}", api_key.trim()))
205        .header("Content-Type", "application/json")
206        .body(AsyncBody::from(serialized))
207        .map_err(|e| RequestError::Other(e.into()))?;
208
209    let mut response = client.send(request).await?;
210
211    if response.status().is_success() {
212        let mut body = String::new();
213        response
214            .body_mut()
215            .read_to_string(&mut body)
216            .await
217            .map_err(|e| RequestError::Other(e.into()))?;
218
219        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
220    } else {
221        let mut body = String::new();
222        response
223            .body_mut()
224            .read_to_string(&mut body)
225            .await
226            .map_err(|e| RequestError::Other(e.into()))?;
227
228        Err(RequestError::HttpResponseError {
229            provider: "openai".to_owned(),
230            status_code: response.status(),
231            body,
232            headers: response.headers().clone(),
233        })
234    }
235}
236
237/// Retrieve batch status
238pub async fn retrieve_batch(
239    client: &dyn HttpClient,
240    api_url: &str,
241    api_key: &str,
242    batch_id: &str,
243) -> Result<Batch, RequestError> {
244    let uri = format!("{api_url}/batches/{batch_id}");
245
246    let request = HttpRequest::builder()
247        .method(Method::GET)
248        .uri(uri)
249        .header("Authorization", format!("Bearer {}", api_key.trim()))
250        .body(AsyncBody::default())
251        .map_err(|e| RequestError::Other(e.into()))?;
252
253    let mut response = client.send(request).await?;
254
255    if response.status().is_success() {
256        let mut body = String::new();
257        response
258            .body_mut()
259            .read_to_string(&mut body)
260            .await
261            .map_err(|e| RequestError::Other(e.into()))?;
262
263        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
264    } else {
265        let mut body = String::new();
266        response
267            .body_mut()
268            .read_to_string(&mut body)
269            .await
270            .map_err(|e| RequestError::Other(e.into()))?;
271
272        Err(RequestError::HttpResponseError {
273            provider: "openai".to_owned(),
274            status_code: response.status(),
275            body,
276            headers: response.headers().clone(),
277        })
278    }
279}
280
281/// Download file content (for batch results)
282pub async fn download_file(
283    client: &dyn HttpClient,
284    api_url: &str,
285    api_key: &str,
286    file_id: &str,
287) -> Result<String, RequestError> {
288    let uri = format!("{api_url}/files/{file_id}/content");
289
290    let request = HttpRequest::builder()
291        .method(Method::GET)
292        .uri(uri)
293        .header("Authorization", format!("Bearer {}", api_key.trim()))
294        .body(AsyncBody::default())
295        .map_err(|e| RequestError::Other(e.into()))?;
296
297    let mut response = client.send(request).await?;
298
299    if response.status().is_success() {
300        let mut body = String::new();
301        response
302            .body_mut()
303            .read_to_string(&mut body)
304            .await
305            .map_err(|e| RequestError::Other(e.into()))?;
306
307        Ok(body)
308    } else {
309        let mut body = String::new();
310        response
311            .body_mut()
312            .read_to_string(&mut body)
313            .await
314            .map_err(|e| RequestError::Other(e.into()))?;
315
316        Err(RequestError::HttpResponseError {
317            provider: "openai".to_owned(),
318            status_code: response.status(),
319            body,
320            headers: response.headers().clone(),
321        })
322    }
323}
324
325/// Parse batch output JSONL into individual results
326pub fn parse_batch_output(content: &str) -> Result<Vec<BatchOutputItem>, serde_json::Error> {
327    content
328        .lines()
329        .filter(|line| !line.trim().is_empty())
330        .map(|line| serde_json::from_str(line))
331        .collect()
332}