1use anyhow::Result;
2use http_client::HttpClient;
3use indoc::indoc;
4use open_ai::{
5 MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
6 Response as OpenAiResponse, batches, non_streaming_completion,
7};
8use reqwest_client::ReqwestClient;
9use sqlez::bindable::Bind;
10use sqlez::bindable::StaticColumnCount;
11use sqlez_macros::sql;
12use std::hash::Hash;
13use std::hash::Hasher;
14use std::path::Path;
15use std::sync::{Arc, Mutex};
16
/// OpenAI client that sends each request immediately over HTTP,
/// with no caching or batching.
pub struct PlainOpenAiClient {
    /// Shared HTTP transport used for all API calls.
    pub http_client: Arc<dyn HttpClient>,
    /// API key read from the `OPENAI_API_KEY` environment variable in `new`.
    pub api_key: String,
}
21
22impl PlainOpenAiClient {
23 pub fn new() -> Result<Self> {
24 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
25 let api_key = std::env::var("OPENAI_API_KEY")
26 .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
27 Ok(Self {
28 http_client,
29 api_key,
30 })
31 }
32
33 pub async fn generate(
34 &self,
35 model: &str,
36 max_tokens: u64,
37 messages: Vec<RequestMessage>,
38 ) -> Result<OpenAiResponse> {
39 let request = OpenAiRequest {
40 model: model.to_string(),
41 messages,
42 stream: false,
43 max_completion_tokens: Some(max_tokens),
44 stop: Vec::new(),
45 temperature: None,
46 tool_choice: None,
47 parallel_tool_calls: None,
48 tools: Vec::new(),
49 prompt_cache_key: None,
50 reasoning_effort: None,
51 };
52
53 let response = non_streaming_completion(
54 self.http_client.as_ref(),
55 OPEN_AI_API_URL,
56 &self.api_key,
57 request,
58 )
59 .await
60 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
61
62 Ok(response)
63 }
64}
65
/// OpenAI client that answers from a local SQLite cache and fulfils cache
/// misses asynchronously through the OpenAI Batch API.
pub struct BatchingOpenAiClient {
    // SQLite cache (the `openai_cache` table). Wrapped in a Mutex so the
    // connection is used from one caller at a time.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    // API key read from the `OPENAI_API_KEY` environment variable in `new`.
    api_key: String,
}
71
/// One row of the `openai_cache` table.
struct CacheRow {
    // Hex-encoded request hash; primary key of the table.
    request_hash: String,
    // Serialized `SerializableRequest` JSON, used to rebuild the request
    // when uploading a batch.
    request: Option<String>,
    // Serialized response JSON (or a synthetic error payload) once a batch
    // result has been imported; NULL while pending.
    response: Option<String>,
    // OpenAI batch id once the request has been uploaded; NULL before.
    batch_id: Option<String>,
}
78
impl StaticColumnCount for CacheRow {
    // A `CacheRow` binds exactly four SQL columns: request_hash, request,
    // response, batch_id (must stay in sync with the `Bind` impl).
    fn column_count() -> usize {
        4
    }
}
84
85impl Bind for CacheRow {
86 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
87 let next_index = statement.bind(&self.request_hash, start_index)?;
88 let next_index = statement.bind(&self.request, next_index)?;
89 let next_index = statement.bind(&self.response, next_index)?;
90 let next_index = statement.bind(&self.batch_id, next_index)?;
91 Ok(next_index)
92 }
93}
94
/// JSON-serializable snapshot of a request, stored in the cache's `request`
/// column so the request can be replayed later when building a batch upload.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
101
/// Flattened (role, text) pair used by `SerializableRequest`; multipart
/// content is collapsed to plain text before storage.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
107
impl BatchingOpenAiClient {
    /// Opens (or creates) the SQLite cache at `cache_path` and ensures the
    /// `openai_cache` table exists.
    ///
    /// Requires the `OPENAI_API_KEY` environment variable to be set.
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;

        // NOTE(review): `to_str().unwrap()` panics on a non-UTF-8 path.
        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS openai_cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
            "},
        )?;
        statement.exec()?;
        // Finalize the DDL statement before the connection moves into the Mutex.
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }

    /// Returns the cached response for this request, if one has been stored.
    ///
    /// Rows whose `response` column is still NULL (awaiting batch results)
    /// are filtered out by the query. A stored payload that fails to
    /// deserialize as `OpenAiResponse` (e.g. the synthetic error payloads
    /// written by the import paths) also surfaces as `None`.
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }

    /// Records the request in the cache (with no response yet) so the next
    /// `upload_pending_requests` call includes it in a batch.
    ///
    /// Uses `INSERT OR IGNORE`, so re-marking a request that is already
    /// cached, uploaded, or answered is a no-op.
    pub fn mark_for_batch(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<()> {
        let request_hash = Self::request_hash(model, max_tokens, messages, seed);

        // Flatten each message to a plain (role, text) pair for JSON storage;
        // multipart content is joined into text by `message_content_to_string`.
        let serializable_messages: Vec<SerializableMessage> = messages
            .iter()
            .map(|msg| SerializableMessage {
                role: message_role_to_string(msg),
                content: message_content_to_string(msg),
            })
            .collect();

        let serializable_request = SerializableRequest {
            model: model.to_string(),
            max_tokens,
            messages: serializable_messages,
        };

        let request = Some(serde_json::to_string(&serializable_request)?);
        let cache_row = CacheRow {
            request_hash,
            request,
            response: None,
            batch_id: None,
        };
        let connection = self.connection.lock().unwrap();
        connection.exec_bound::<CacheRow>(sql!(
            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
            cache_row,
        )
    }

    /// Cache-first generation: returns the cached response when present;
    /// otherwise returns `Ok(None)` after (unless `cache_only`) marking the
    /// request for inclusion in the next batch upload.
    async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
        cache_only: bool,
    ) -> Result<Option<OpenAiResponse>> {
        let response = self.lookup(model, max_tokens, &messages, seed)?;
        if let Some(response) = response {
            return Ok(Some(response));
        }

        if !cache_only {
            self.mark_for_batch(model, max_tokens, &messages, seed)?;
        }

        Ok(None)
    }

    /// Uploads all pending requests as new batches, then downloads results
    /// for any previously uploaded batches that have finished.
    async fn sync_batches(&self) -> Result<()> {
        let _batch_ids = self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }

    /// Fetches and stores results for explicitly named batch ids (e.g.
    /// batches created elsewhere), upserting rows into the cache.
    ///
    /// Batches not in the `completed` state are skipped with a warning.
    /// Per-request failures are recorded as synthetic JSON error payloads so
    /// they are not treated as still-pending.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            // (request_hash, response_json, batch_id) tuples to upsert.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // Recover the cache key from the custom id assigned at upload
                // time ("req_hash_<hash>"); tolerate ids without the prefix.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        // Synthetic error payload. NOTE(review): `lookup`
                        // deserializes with `.ok()`, so this presumably reads
                        // back as a cache miss — confirm `OpenAiResponse`
                        // cannot deserialize from this shape.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            // Apply all upserts atomically within a savepoint. The inner
            // SELECT re-reads the existing row's `request` column so an
            // INSERT OR REPLACE does not wipe locally stored request JSON.
            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }

    /// Polls every uploaded batch whose responses are not yet stored and
    /// imports the results of those that have completed.
    ///
    /// NOTE(review): the per-result handling here largely duplicates
    /// `import_batches`; consider extracting a shared helper. Also note that
    /// error results are stored but not counted in the logged total.
    async fn download_finished_batches(&self) -> Result<()> {
        // Collect the ids in a block so the connection lock is released
        // before any `.await` below.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status == "completed" {
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                // (response_json, request_hash) pairs for the UPDATE below.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;

                for result in results {
                    // Strip the "req_hash_" prefix added at upload time.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                // Store all responses atomically within a savepoint.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }

    /// Uploads all cached requests not yet assigned to a batch, in chunks of
    /// `BATCH_CHUNK_SIZE` rows per batch, and returns the created batch ids.
    ///
    /// Loops until the pending query returns no rows; each iteration marks
    /// its rows with the new batch id, so progress is durable across runs.
    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
        const BATCH_CHUNK_SIZE: i32 = 16_000;
        let mut all_batch_ids = Vec::new();
        let mut total_uploaded = 0;

        loop {
            // Fetch the next chunk inside a block so the lock is released
            // before the network awaits below.
            let rows: Vec<(String, String)> = {
                let connection = self.connection.lock().unwrap();
                let q = sql!(
                    SELECT request_hash, request FROM openai_cache
                    WHERE batch_id IS NULL AND response IS NULL
                    LIMIT ?
                );
                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
            };

            if rows.is_empty() {
                break;
            }

            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();

            // Build the JSONL batch file: one serialized request per line,
            // keyed by "req_hash_<hash>" so results can be matched back.
            let mut jsonl_content = String::new();
            for (hash, request_str) in &rows {
                // NOTE(review): panics if a cached `request` row holds
                // malformed JSON; consider propagating an error instead.
                let serializable_request: SerializableRequest =
                    serde_json::from_str(request_str).unwrap();

                // Rehydrate wire-format messages from the stored role/text
                // pairs. NOTE(review): there is no "tool" arm here, so a
                // message stored with role "tool" is replayed as a user
                // message — confirm that is intended.
                let messages: Vec<RequestMessage> = serializable_request
                    .messages
                    .into_iter()
                    .map(|msg| match msg.role.as_str() {
                        "user" => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                        "assistant" => RequestMessage::Assistant {
                            content: Some(MessageContent::Plain(msg.content)),
                            tool_calls: Vec::new(),
                        },
                        "system" => RequestMessage::System {
                            content: MessageContent::Plain(msg.content),
                        },
                        _ => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                    })
                    .collect();

                let request = OpenAiRequest {
                    model: serializable_request.model,
                    messages,
                    stream: false,
                    max_completion_tokens: Some(serializable_request.max_tokens),
                    stop: Vec::new(),
                    temperature: None,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: Vec::new(),
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                let custom_id = format!("req_hash_{}", hash);
                let batch_item = batches::BatchRequestItem::new(custom_id, request);
                let line = batch_item
                    .to_jsonl_line()
                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
                jsonl_content.push_str(&line);
                jsonl_content.push('\n');
            }

            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
            let file_obj = batches::upload_batch_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &filename,
                jsonl_content.into_bytes(),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;

            let batch = batches::create_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batches::CreateBatchRequest::new(file_obj.id),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;

            // Record the batch id for every uploaded row so the next loop
            // iteration's SELECT no longer returns them.
            {
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_upload", || {
                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for hash in &request_hashes {
                        exec((batch.id.as_str(), hash.as_str()))?;
                    }
                    Ok(())
                })?;
            }

            let batch_len = rows.len();
            total_uploaded += batch_len;
            log::info!(
                "Uploaded batch {} with {} requests ({} total)",
                batch.id,
                batch_len,
                total_uploaded
            );

            all_batch_ids.push(batch.id);
        }

        if !all_batch_ids.is_empty() {
            log::info!(
                "Finished uploading {} batches with {} total requests",
                all_batch_ids.len(),
                total_uploaded
            );
        }

        Ok(all_batch_ids)
    }

    /// Derives the cache key for a request: a 16-hex-digit hash over the
    /// provider tag, model, max_tokens, message contents, and optional seed.
    ///
    /// NOTE(review): only message *contents* are hashed — roles are not — so
    /// two requests differing only in roles collide on the same key.
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable
    /// across Rust releases, so a toolchain upgrade could silently invalidate
    /// the entire cache — confirm this is acceptable.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        "openai".hash(&mut hasher);
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(msg).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        format!("{request_hash:016x}")
    }
}
587
588fn message_role_to_string(msg: &RequestMessage) -> String {
589 match msg {
590 RequestMessage::User { .. } => "user".to_string(),
591 RequestMessage::Assistant { .. } => "assistant".to_string(),
592 RequestMessage::System { .. } => "system".to_string(),
593 RequestMessage::Tool { .. } => "tool".to_string(),
594 }
595}
596
597fn message_content_to_string(msg: &RequestMessage) -> String {
598 match msg {
599 RequestMessage::User { content } => content_to_string(content),
600 RequestMessage::Assistant { content, .. } => {
601 content.as_ref().map(content_to_string).unwrap_or_default()
602 }
603 RequestMessage::System { content } => content_to_string(content),
604 RequestMessage::Tool { content, .. } => content_to_string(content),
605 }
606}
607
608fn content_to_string(content: &MessageContent) -> String {
609 match content {
610 MessageContent::Plain(text) => text.clone(),
611 MessageContent::Multipart(parts) => parts
612 .iter()
613 .filter_map(|part| match part {
614 open_ai::MessagePart::Text { text } => Some(text.clone()),
615 _ => None,
616 })
617 .collect::<Vec<String>>()
618 .join("\n"),
619 }
620}
621
/// Front-end over the two client implementations, plus a `Dummy` placeholder
/// that panics on any use.
pub enum OpenAiClient {
    // Sends each request immediately over HTTP.
    Plain(PlainOpenAiClient),
    // Serves from a SQLite cache and resolves misses via the Batch API.
    Batch(BatchingOpenAiClient),
    #[allow(dead_code)]
    Dummy,
}
628
629impl OpenAiClient {
630 pub fn plain() -> Result<Self> {
631 Ok(Self::Plain(PlainOpenAiClient::new()?))
632 }
633
634 pub fn batch(cache_path: &Path) -> Result<Self> {
635 Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
636 }
637
638 #[allow(dead_code)]
639 pub fn dummy() -> Self {
640 Self::Dummy
641 }
642
643 pub async fn generate(
644 &self,
645 model: &str,
646 max_tokens: u64,
647 messages: Vec<RequestMessage>,
648 seed: Option<usize>,
649 cache_only: bool,
650 ) -> Result<Option<OpenAiResponse>> {
651 match self {
652 OpenAiClient::Plain(plain_client) => plain_client
653 .generate(model, max_tokens, messages)
654 .await
655 .map(Some),
656 OpenAiClient::Batch(batching_client) => {
657 batching_client
658 .generate(model, max_tokens, messages, seed, cache_only)
659 .await
660 }
661 OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
662 }
663 }
664
665 pub async fn sync_batches(&self) -> Result<()> {
666 match self {
667 OpenAiClient::Plain(_) => Ok(()),
668 OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
669 OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
670 }
671 }
672
673 pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
674 match self {
675 OpenAiClient::Plain(_) => {
676 anyhow::bail!("Import batches is only supported with batching client")
677 }
678 OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
679 OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
680 }
681 }
682}