use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
use open_ai::{
    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
    Response as OpenAiResponse, batches, non_streaming_completion,
};
use reqwest_client::ReqwestClient;
use sqlez::bindable::{Bind, StaticColumnCount};
use sqlez_macros::sql;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::{Arc, Mutex};

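/// A direct, non-batching client: each `generate` call issues one synchronous
/// request against the OpenAI completions endpoint. Requires the
/// `OPENAI_API_KEY` environment variable.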
pub struct PlainOpenAiClient {
    pub http_client: Arc<dyn HttpClient>,
    pub api_key: String,
}

impl PlainOpenAiClient {
    pub fn new() -> Result<Self> {
        let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
        Ok(Self {
            http_client,
            api_key,
        })
    }

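    /// Sends a single non-streaming chat completion request. Optional request
    /// fields (temperature, tools, stop sequences, reasoning effort) are left
    /// unset; only the model, messages, and completion-token limit vary.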
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
    ) -> Result<OpenAiResponse> {
        let request = OpenAiRequest {
            model: model.to_string(),
            messages,
            stream: false,
            max_completion_tokens: Some(max_tokens),
            stop: Vec::new(),
            temperature: None,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: Vec::new(),
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        let response = non_streaming_completion(
            self.http_client.as_ref(),
            OPEN_AI_API_URL,
            &self.api_key,
            request,
        )
        .await
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;

        Ok(response)
    }
}

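/// A client built on the OpenAI Batch API with a persistent SQLite cache.
///
/// Each request is keyed by `request_hash` and moves through three states in
/// the `openai_cache` table:
///
/// 1. `request` set, `batch_id` and `response` NULL: marked, awaiting upload.
/// 2. `batch_id` set, `response` NULL: uploaded, awaiting batch completion.
/// 3. `response` set: finished, holding either the completion JSON or a
///    recorded error payload.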
pub struct BatchingOpenAiClient {
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}

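/// One row of the `openai_cache` table; see `BatchingOpenAiClient` for how
/// the `request`, `response`, and `batch_id` columns progress over a
/// request's lifetime.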
struct CacheRow {
    request_hash: String,
    request: Option<String>,
    response: Option<String>,
    batch_id: Option<String>,
}

impl StaticColumnCount for CacheRow {
    fn column_count() -> usize {
        4
    }
}

impl Bind for CacheRow {
    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
        let next_index = statement.bind(&self.request_hash, start_index)?;
        let next_index = statement.bind(&self.request, next_index)?;
        let next_index = statement.bind(&self.response, next_index)?;
        let next_index = statement.bind(&self.batch_id, next_index)?;
        Ok(next_index)
    }
}

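/// A flattened mirror of `OpenAiRequest` that round-trips through the cache.
/// Message content is reduced to plain text via `message_content_to_string`,
/// so multipart messages are stored lossily (only their text parts survive).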
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}

impl BatchingOpenAiClient {
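    /// Opens (or creates) the SQLite cache at `cache_path` and ensures the
    /// `openai_cache` table exists. Requires the `OPENAI_API_KEY` environment
    /// variable.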
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;

        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS openai_cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
            "},
        )?;
        statement.exec()?;
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }

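    /// Returns the cached response for this request, if a batch has already
    /// produced one. Rows whose stored text fails to deserialize as an
    /// `OpenAiResponse` (e.g. recorded error payloads) are treated as misses.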
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(sql!(
            SELECT response FROM openai_cache
            WHERE request_hash = ? AND response IS NOT NULL
        ))?(request_hash.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }

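    /// Records a request so the next `sync_batches` call uploads it. The
    /// `INSERT OR IGNORE` keeps this idempotent: re-marking a request that is
    /// already pending, in flight, or answered leaves the existing row intact.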
    pub fn mark_for_batch(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<()> {
        let request_hash = Self::request_hash(model, max_tokens, messages, seed);

        let serializable_messages: Vec<SerializableMessage> = messages
            .iter()
            .map(|msg| SerializableMessage {
                role: message_role_to_string(msg),
                content: message_content_to_string(msg),
            })
            .collect();

        let serializable_request = SerializableRequest {
            model: model.to_string(),
            max_tokens,
            messages: serializable_messages,
        };

        let request = Some(serde_json::to_string(&serializable_request)?);
        let cache_row = CacheRow {
            request_hash,
            request,
            response: None,
            batch_id: None,
        };
        let connection = self.connection.lock().unwrap();
        connection.exec_bound::<CacheRow>(sql!(
            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id)
            VALUES (?, ?, ?, ?)
        ))?(cache_row)
    }

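    /// Batch-aware lookup: returns the cached response when one exists;
    /// otherwise marks the request for the next upload and returns `Ok(None)`,
    /// so callers should retry after a later `sync_batches` completes.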
    async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        if let Some(response) = self.lookup(model, max_tokens, &messages, seed)? {
            return Ok(Some(response));
        }

        self.mark_for_batch(model, max_tokens, &messages, seed)?;

        Ok(None)
    }

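    /// Uploads any pending requests as new batches, then downloads results
    /// for every previously uploaded batch that has since completed.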
    async fn sync_batches(&self) -> Result<()> {
        let _batch_ids = self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }

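    /// Imports results for externally known batch ids, writing successful
    /// responses and per-request error payloads into the cache. Unlike
    /// `download_finished_batches`, this upserts rows even for requests that
    /// were never marked locally (their `request` column is carried over when
    /// present, NULL otherwise).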
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }

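    /// Polls every batch that was uploaded but not yet answered and, for each
    /// completed batch, writes its results back into the cache. Batches still
    /// in progress are left for a later sync.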
    async fn download_finished_batches(&self) -> Result<()> {
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(
                SELECT DISTINCT batch_id FROM openai_cache
                WHERE batch_id IS NOT NULL AND response IS NULL
            );
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status == "completed" {
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;

                for result in results {
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }

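    /// Uploads all unsent cache rows in chunks of `BATCH_CHUNK_SIZE`, creating
    /// one batch per chunk and stamping each uploaded row with the returned
    /// batch id. Returns the ids of all batches created.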
    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
        const BATCH_CHUNK_SIZE: i32 = 16_000;
        let mut all_batch_ids = Vec::new();
        let mut total_uploaded = 0;

        loop {
            let rows: Vec<(String, String)> = {
                let connection = self.connection.lock().unwrap();
                let q = sql!(
                    SELECT request_hash, request FROM openai_cache
                    WHERE batch_id IS NULL AND response IS NULL
                    LIMIT ?
                );
                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
            };

            if rows.is_empty() {
                break;
            }

            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();

            let mut jsonl_content = String::new();
            for (hash, request_str) in &rows {
                let serializable_request: SerializableRequest =
                    serde_json::from_str(request_str)?;

                let messages: Vec<RequestMessage> = serializable_request
                    .messages
                    .into_iter()
                    .map(|msg| match msg.role.as_str() {
                        "user" => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                        "assistant" => RequestMessage::Assistant {
                            content: Some(MessageContent::Plain(msg.content)),
                            tool_calls: Vec::new(),
                        },
                        "system" => RequestMessage::System {
                            content: MessageContent::Plain(msg.content),
                        },
                        _ => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                    })
                    .collect();

                let request = OpenAiRequest {
                    model: serializable_request.model,
                    messages,
                    stream: false,
                    max_completion_tokens: Some(serializable_request.max_tokens),
                    stop: Vec::new(),
                    temperature: None,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: Vec::new(),
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                let custom_id = format!("req_hash_{}", hash);
                let batch_item = batches::BatchRequestItem::new(custom_id, request);
                let line = batch_item
                    .to_jsonl_line()
                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
                jsonl_content.push_str(&line);
                jsonl_content.push('\n');
            }

            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
            let file_obj = batches::upload_batch_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &filename,
                jsonl_content.into_bytes(),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;

            let batch = batches::create_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batches::CreateBatchRequest::new(file_obj.id),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;

            {
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_upload", || {
                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for hash in &request_hashes {
                        exec((batch.id.as_str(), hash.as_str()))?;
                    }
                    Ok(())
                })?;
            }

            let batch_len = rows.len();
            total_uploaded += batch_len;
            log::info!(
                "Uploaded batch {} with {} requests ({} total)",
                batch.id,
                batch_len,
                total_uploaded
            );

            all_batch_ids.push(batch.id);
        }

        if !all_batch_ids.is_empty() {
            log::info!(
                "Finished uploading {} batches with {} total requests",
                all_batch_ids.len(),
                total_uploaded
            );
        }

        Ok(all_batch_ids)
    }

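    /// Derives the cache key for a request from the model, token limit,
    /// message contents, and optional seed. Note that message roles do not
    /// contribute to the hash, and `DefaultHasher`'s algorithm is not
    /// guaranteed to be stable across Rust releases, so a toolchain upgrade
    /// may invalidate existing cache entries.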
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        "openai".hash(&mut hasher);
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(msg).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        format!("{request_hash:016x}")
    }
}

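/// Maps a `RequestMessage` variant to the role string stored in the cache and
/// reconstructed in `upload_pending_requests`.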
fn message_role_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { .. } => "user".to_string(),
        RequestMessage::Assistant { .. } => "assistant".to_string(),
        RequestMessage::System { .. } => "system".to_string(),
        RequestMessage::Tool { .. } => "tool".to_string(),
    }
}

fn message_content_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { content } => content_to_string(content),
        RequestMessage::Assistant { content, .. } => {
            content.as_ref().map(content_to_string).unwrap_or_default()
        }
        RequestMessage::System { content } => content_to_string(content),
        RequestMessage::Tool { content, .. } => content_to_string(content),
    }
}

fn content_to_string(content: &MessageContent) -> String {
    match content {
        MessageContent::Plain(text) => text.clone(),
        MessageContent::Multipart(parts) => parts
            .iter()
            .filter_map(|part| match part {
                open_ai::MessagePart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n"),
    }
}

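/// Front-end over the two client flavors plus a `Dummy` placeholder for code
/// paths that must never issue requests.
///
/// A minimal sketch of the batching flow (the cache path and model name below
/// are illustrative, not defaults of this module):
///
/// ```ignore
/// let client = OpenAiClient::batch(Path::new("openai_cache.db"))?;
/// // First pass: a cache miss is recorded and `generate` returns Ok(None).
/// if client.generate("gpt-4o-mini", 1024, messages, None).await?.is_none() {
///     // Upload pending requests and pull results for finished batches;
///     // retry `generate` once the batch completes.
///     client.sync_batches().await?;
/// }
/// ```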
pub enum OpenAiClient {
    Plain(PlainOpenAiClient),
    Batch(BatchingOpenAiClient),
    #[allow(dead_code)]
    Dummy,
}

impl OpenAiClient {
    pub fn plain() -> Result<Self> {
        Ok(Self::Plain(PlainOpenAiClient::new()?))
    }

    pub fn batch(cache_path: &Path) -> Result<Self> {
        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
    }

    #[allow(dead_code)]
    pub fn dummy() -> Self {
        Self::Dummy
    }

    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        match self {
            OpenAiClient::Plain(plain_client) => plain_client
                .generate(model, max_tokens, messages)
                .await
                .map(Some),
            OpenAiClient::Batch(batching_client) => {
                batching_client
                    .generate(model, max_tokens, messages, seed)
                    .await
            }
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn sync_batches(&self) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => Ok(()),
            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => {
                anyhow::bail!("Import batches is only supported with batching client")
            }
            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }
}