use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
use open_ai::{
    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
    Response as OpenAiResponse, batches, non_streaming_completion,
};
use reqwest_client::ReqwestClient;
use sqlez::bindable::{Bind, StaticColumnCount};
use sqlez_macros::sql;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::{Arc, Mutex};

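/// A thin OpenAI client that sends each request immediately as a single
/// non-streaming chat completion.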
pub struct PlainOpenAiClient {
    pub http_client: Arc<dyn HttpClient>,
    pub api_key: String,
}

impl PlainOpenAiClient {
    pub fn new() -> Result<Self> {
        let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
        Ok(Self {
            http_client,
            api_key,
        })
    }

    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
    ) -> Result<OpenAiResponse> {
        let request = OpenAiRequest {
            model: model.to_string(),
            messages,
            stream: false,
            stream_options: None,
            max_completion_tokens: Some(max_tokens),
            stop: Vec::new(),
            temperature: None,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: Vec::new(),
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        let response = non_streaming_completion(
            self.http_client.as_ref(),
            OPEN_AI_API_URL,
            &self.api_key,
            request,
        )
        .await
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;

        Ok(response)
    }
}

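/// An OpenAI client backed by a SQLite cache. Requests are recorded locally,
/// uploaded through the OpenAI Batch API, and their responses are downloaded
/// back into the cache for later lookup.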
pub struct BatchingOpenAiClient {
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}

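/// One row of the on-disk cache: `request_hash` is the primary key;
/// `request` holds the serialized request, `response` the serialized
/// response (once downloaded), and `batch_id` the OpenAI batch the request
/// was uploaded in.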
struct CacheRow {
    request_hash: String,
    request: Option<String>,
    response: Option<String>,
    batch_id: Option<String>,
}

impl StaticColumnCount for CacheRow {
    fn column_count() -> usize {
        4
    }
}

impl Bind for CacheRow {
    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
        let next_index = statement.bind(&self.request_hash, start_index)?;
        let next_index = statement.bind(&self.request, next_index)?;
        let next_index = statement.bind(&self.response, next_index)?;
        let next_index = statement.bind(&self.batch_id, next_index)?;
        Ok(next_index)
    }
}

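/// A stable, JSON-serializable snapshot of a request as stored in the cache.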
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}

impl BatchingOpenAiClient {
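    /// Opens (or creates) the SQLite cache at `cache_path` and ensures the
    /// `openai_cache` table exists.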
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;

        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS openai_cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
            "},
        )?;
        statement.exec()?;
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }

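    /// Returns the cached response for this request, if one has already been
    /// downloaded.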
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }

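    /// Stores the request in the cache with no response and no batch id, so
    /// the next upload pass will include it.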
    pub fn mark_for_batch(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<()> {
        let request_hash = Self::request_hash(model, max_tokens, messages, seed);

        let serializable_messages: Vec<SerializableMessage> = messages
            .iter()
            .map(|msg| SerializableMessage {
                role: message_role_to_string(msg),
                content: message_content_to_string(msg),
            })
            .collect();

        let serializable_request = SerializableRequest {
            model: model.to_string(),
            max_tokens,
            messages: serializable_messages,
        };

        let request = Some(serde_json::to_string(&serializable_request)?);
        let cache_row = CacheRow {
            request_hash,
            request,
            response: None,
            batch_id: None,
        };
        let connection = self.connection.lock().unwrap();
        connection.exec_bound::<CacheRow>(sql!(
            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id)
            VALUES (?, ?, ?, ?)
        ))?(cache_row)
    }

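    /// Cache-first generation: returns the cached response when available;
    /// on a miss, marks the request for the next batch upload (unless
    /// `cache_only` is set) and returns `Ok(None)`.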
    async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
        cache_only: bool,
    ) -> Result<Option<OpenAiResponse>> {
        if let Some(response) = self.lookup(model, max_tokens, &messages, seed)? {
            return Ok(Some(response));
        }

        if !cache_only {
            self.mark_for_batch(model, max_tokens, &messages, seed)?;
        }

        Ok(None)
    }

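    /// Uploads any pending requests as new batches, then downloads results
    /// for batches that have finished.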
    async fn sync_batches(&self) -> Result<()> {
        let _batch_ids = self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }

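    /// Counts cache rows that were uploaded in a batch but have no response
    /// yet.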
    pub fn pending_batch_count(&self) -> Result<usize> {
        let connection = self.connection.lock().unwrap();
        let counts: Vec<i32> = connection.select(
            sql!(SELECT COUNT(*) FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL),
        )?()?;
        Ok(counts.into_iter().next().unwrap_or(0) as usize)
    }

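    /// Imports results for an explicit list of batch ids. Rows are upserted,
    /// so results can be imported even when the originating request is
    /// missing from the local cache.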
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }

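    /// Polls every batch that still lacks responses and, for each completed
    /// one, downloads its results and stores them in the cache.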
    async fn download_finished_batches(&self) -> Result<()> {
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status == "completed" {
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;

                for result in results {
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }

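    /// Uploads all cached requests not yet assigned to a batch, in chunks,
    /// and returns the ids of the batches it created.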
    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
        const BATCH_CHUNK_SIZE: i32 = 16_000;
        let mut all_batch_ids = Vec::new();
        let mut total_uploaded = 0;

        loop {
            let rows: Vec<(String, String)> = {
                let connection = self.connection.lock().unwrap();
                let q = sql!(
                    SELECT request_hash, request FROM openai_cache
                    WHERE batch_id IS NULL AND response IS NULL
                    LIMIT ?
                );
                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
            };

            if rows.is_empty() {
                break;
            }

            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();

            let mut jsonl_content = String::new();
            for (hash, request_str) in &rows {
                // Propagate deserialization failures instead of panicking on a
                // corrupt cache row.
                let serializable_request: SerializableRequest = serde_json::from_str(request_str)?;

                let messages: Vec<RequestMessage> = serializable_request
                    .messages
                    .into_iter()
                    .map(|msg| match msg.role.as_str() {
                        "user" => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                        "assistant" => RequestMessage::Assistant {
                            content: Some(MessageContent::Plain(msg.content)),
                            tool_calls: Vec::new(),
                            reasoning_content: None,
                        },
                        "system" => RequestMessage::System {
                            content: MessageContent::Plain(msg.content),
                        },
                        _ => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                    })
                    .collect();

                let request = OpenAiRequest {
                    model: serializable_request.model,
                    messages,
                    stream: false,
                    stream_options: None,
                    max_completion_tokens: Some(serializable_request.max_tokens),
                    stop: Vec::new(),
                    temperature: None,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: Vec::new(),
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                let custom_id = format!("req_hash_{}", hash);
                let batch_item = batches::BatchRequestItem::new(custom_id, request);
                let line = batch_item
                    .to_jsonl_line()
                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
                jsonl_content.push_str(&line);
                jsonl_content.push('\n');
            }

            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
            let file_obj = batches::upload_batch_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &filename,
                jsonl_content.into_bytes(),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;

            let batch = batches::create_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batches::CreateBatchRequest::new(file_obj.id),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;

            {
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_upload", || {
                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for hash in &request_hashes {
                        exec((batch.id.as_str(), hash.as_str()))?;
                    }
                    Ok(())
                })?;
            }

            let batch_len = rows.len();
            total_uploaded += batch_len;
            log::info!(
                "Uploaded batch {} with {} requests ({} total)",
                batch.id,
                batch_len,
                total_uploaded
            );

            all_batch_ids.push(batch.id);
        }

        if !all_batch_ids.is_empty() {
            log::info!(
                "Finished uploading {} batches with {} total requests",
                all_batch_ids.len(),
                total_uploaded
            );
        }

        Ok(all_batch_ids)
    }

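    /// Hashes the model, token limit, message contents, and optional seed
    /// into a 16-hex-digit cache key. Note that message roles are not part
    /// of the hash, and `DefaultHasher` output is not guaranteed to be stable
    /// across Rust releases.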
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        "openai".hash(&mut hasher);
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(msg).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        format!("{request_hash:016x}")
    }
}

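/// Maps each `RequestMessage` variant to its OpenAI role string.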
fn message_role_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { .. } => "user".to_string(),
        RequestMessage::Assistant { .. } => "assistant".to_string(),
        RequestMessage::System { .. } => "system".to_string(),
        RequestMessage::Tool { .. } => "tool".to_string(),
    }
}

fn message_content_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { content } => content_to_string(content),
        RequestMessage::Assistant { content, .. } => {
            content.as_ref().map(content_to_string).unwrap_or_default()
        }
        RequestMessage::System { content } => content_to_string(content),
        RequestMessage::Tool { content, .. } => content_to_string(content),
    }
}

fn content_to_string(content: &MessageContent) -> String {
    match content {
        MessageContent::Plain(text) => text.clone(),
        MessageContent::Multipart(parts) => parts
            .iter()
            .filter_map(|part| match part {
                open_ai::MessagePart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n"),
    }
}

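/// Front end over the concrete clients: `Plain` sends requests immediately,
/// `Batch` goes through the SQLite-backed batch cache, and `Dummy` exists for
/// code paths that must construct a client but never call it.
///
/// A minimal usage sketch (hypothetical call site; assumes `OPENAI_API_KEY`
/// is set, the cache path is writable, an async runtime is available, and a
/// placeholder model name):
///
/// ```ignore
/// let client = OpenAiClient::batch(Path::new("./openai_cache.db"))?;
/// let messages = vec![RequestMessage::User {
///     content: MessageContent::Plain("Hello".into()),
/// }];
/// // A first call typically returns Ok(None): the request is only marked
/// // for upload. sync_batches() then uploads pending requests and pulls
/// // down any finished batches so a later generate() hits the cache.
/// if client.generate("gpt-4o-mini", 256, messages, None, false).await?.is_none() {
///     client.sync_batches().await?;
/// }
/// ```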
pub enum OpenAiClient {
    Plain(PlainOpenAiClient),
    Batch(BatchingOpenAiClient),
    #[allow(dead_code)]
    Dummy,
}

impl OpenAiClient {
    pub fn plain() -> Result<Self> {
        Ok(Self::Plain(PlainOpenAiClient::new()?))
    }

    pub fn batch(cache_path: &Path) -> Result<Self> {
        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
    }

    #[allow(dead_code)]
    pub fn dummy() -> Self {
        Self::Dummy
    }

    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
        cache_only: bool,
    ) -> Result<Option<OpenAiResponse>> {
        match self {
            OpenAiClient::Plain(plain_client) => plain_client
                .generate(model, max_tokens, messages)
                .await
                .map(Some),
            OpenAiClient::Batch(batching_client) => {
                batching_client
                    .generate(model, max_tokens, messages, seed, cache_only)
                    .await
            }
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn sync_batches(&self) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => Ok(()),
            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub fn pending_batch_count(&self) -> Result<usize> {
        match self {
            OpenAiClient::Plain(_) => Ok(0),
            OpenAiClient::Batch(batching_client) => batching_client.pending_batch_count(),
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => {
                anyhow::bail!("Import batches is only supported with batching client")
            }
            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }
}