1use anyhow::Result;
2use http_client::HttpClient;
3use indoc::indoc;
4use open_ai::{
5 MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
6 Response as OpenAiResponse, batches, non_streaming_completion,
7};
8use reqwest_client::ReqwestClient;
9use sqlez::bindable::Bind;
10use sqlez::bindable::StaticColumnCount;
11use sqlez_macros::sql;
12use std::hash::Hash;
13use std::hash::Hasher;
14use std::path::Path;
15use std::sync::{Arc, Mutex};
16
/// OpenAI client that sends each request immediately over HTTP,
/// with no caching or batching.
pub struct PlainOpenAiClient {
    // Shared HTTP transport used for all API calls.
    pub http_client: Arc<dyn HttpClient>,
    // Value read from the OPENAI_API_KEY environment variable.
    pub api_key: String,
}
21
22impl PlainOpenAiClient {
23 pub fn new() -> Result<Self> {
24 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
25 let api_key = std::env::var("OPENAI_API_KEY")
26 .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
27 Ok(Self {
28 http_client,
29 api_key,
30 })
31 }
32
33 pub async fn generate(
34 &self,
35 model: &str,
36 max_tokens: u64,
37 messages: Vec<RequestMessage>,
38 ) -> Result<OpenAiResponse> {
39 let request = OpenAiRequest {
40 model: model.to_string(),
41 messages,
42 stream: false,
43 max_completion_tokens: Some(max_tokens),
44 stop: Vec::new(),
45 temperature: None,
46 tool_choice: None,
47 parallel_tool_calls: None,
48 tools: Vec::new(),
49 prompt_cache_key: None,
50 reasoning_effort: None,
51 };
52
53 let response = non_streaming_completion(
54 self.http_client.as_ref(),
55 OPEN_AI_API_URL,
56 &self.api_key,
57 request,
58 )
59 .await
60 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
61
62 Ok(response)
63 }
64}
65
/// OpenAI client that defers work to the Batch API: requests are recorded in
/// a local SQLite cache, uploaded in bulk, and results are read back later.
pub struct BatchingOpenAiClient {
    // SQLite cache of requests/responses; wrapped in a Mutex so the
    // `&self` methods can use the connection.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}
71
/// One row of the `openai_cache` table.
struct CacheRow {
    // Hex-encoded hash of (model, max_tokens, messages); the primary key.
    request_hash: String,
    // Serialized `SerializableRequest` JSON, when the request body is stored.
    request: Option<String>,
    // Serialized response JSON (or an error payload), once available.
    response: Option<String>,
    // Id of the OpenAI batch this request was uploaded in, if any.
    batch_id: Option<String>,
}
78
impl StaticColumnCount for CacheRow {
    /// A `CacheRow` binds exactly four SQL columns:
    /// request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
84
85impl Bind for CacheRow {
86 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
87 let next_index = statement.bind(&self.request_hash, start_index)?;
88 let next_index = statement.bind(&self.request, next_index)?;
89 let next_index = statement.bind(&self.response, next_index)?;
90 let next_index = statement.bind(&self.batch_id, next_index)?;
91 Ok(next_index)
92 }
93}
94
/// JSON-storable form of a completion request, persisted in the cache's
/// `request` column and reconstituted at upload time.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
101
/// Flattened (role, text) form of a `RequestMessage`; multipart content is
/// collapsed to plain text before storage.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
107
108impl BatchingOpenAiClient {
109 fn new(cache_path: &Path) -> Result<Self> {
110 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
111 let api_key = std::env::var("OPENAI_API_KEY")
112 .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
113
114 let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
115 let mut statement = sqlez::statement::Statement::prepare(
116 &connection,
117 indoc! {"
118 CREATE TABLE IF NOT EXISTS openai_cache (
119 request_hash TEXT PRIMARY KEY,
120 request TEXT,
121 response TEXT,
122 batch_id TEXT
123 );
124 "},
125 )?;
126 statement.exec()?;
127 drop(statement);
128
129 Ok(Self {
130 connection: Mutex::new(connection),
131 http_client,
132 api_key,
133 })
134 }
135
136 pub fn lookup(
137 &self,
138 model: &str,
139 max_tokens: u64,
140 messages: &[RequestMessage],
141 ) -> Result<Option<OpenAiResponse>> {
142 let request_hash_str = Self::request_hash(model, max_tokens, messages);
143 let connection = self.connection.lock().unwrap();
144 let response: Vec<String> = connection.select_bound(
145 &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
146 )?(request_hash_str.as_str())?;
147 Ok(response
148 .into_iter()
149 .next()
150 .and_then(|text| serde_json::from_str(&text).ok()))
151 }
152
153 pub fn mark_for_batch(
154 &self,
155 model: &str,
156 max_tokens: u64,
157 messages: &[RequestMessage],
158 ) -> Result<()> {
159 let request_hash = Self::request_hash(model, max_tokens, messages);
160
161 let serializable_messages: Vec<SerializableMessage> = messages
162 .iter()
163 .map(|msg| SerializableMessage {
164 role: message_role_to_string(msg),
165 content: message_content_to_string(msg),
166 })
167 .collect();
168
169 let serializable_request = SerializableRequest {
170 model: model.to_string(),
171 max_tokens,
172 messages: serializable_messages,
173 };
174
175 let request = Some(serde_json::to_string(&serializable_request)?);
176 let cache_row = CacheRow {
177 request_hash,
178 request,
179 response: None,
180 batch_id: None,
181 };
182 let connection = self.connection.lock().unwrap();
183 connection.exec_bound::<CacheRow>(sql!(
184 INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
185 cache_row,
186 )
187 }
188
189 async fn generate(
190 &self,
191 model: &str,
192 max_tokens: u64,
193 messages: Vec<RequestMessage>,
194 ) -> Result<Option<OpenAiResponse>> {
195 let response = self.lookup(model, max_tokens, &messages)?;
196 if let Some(response) = response {
197 return Ok(Some(response));
198 }
199
200 self.mark_for_batch(model, max_tokens, &messages)?;
201
202 Ok(None)
203 }
204
205 async fn sync_batches(&self) -> Result<()> {
206 let _batch_ids = self.upload_pending_requests().await?;
207 self.download_finished_batches().await
208 }
209
    /// Downloads and caches the results of the given, explicitly-named
    /// batches.
    ///
    /// For each id: retrieves the batch, skips it unless its status is
    /// "completed", downloads and parses the output file, then upserts one
    /// cache row per result inside a savepoint. Both successful responses
    /// and error payloads are written to the `response` column, so failures
    /// are recorded too.
    ///
    /// # Errors
    /// Fails if a batch cannot be retrieved, a completed batch has no output
    /// file, the output cannot be downloaded or parsed, or a database write
    /// fails.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            // Only completed batches have results; anything else is logged
            // and skipped (not fatal) so the remaining ids still import.
            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            // (request_hash, response_json, batch_id) triples to upsert below.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // Custom ids are written as "req_hash_<hash>" at upload time;
                // recover the bare hash, falling back to the raw custom id
                // for batches created elsewhere.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        // Store a structured error payload so the failure is
                        // recorded instead of silently dropped.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            // Apply all row writes atomically inside a savepoint. The inner
            // SELECT carries over any stored request text for an existing row.
            // NOTE(review): this relies on the subquery being evaluated before
            // INSERT OR REPLACE deletes the conflicting row — confirm against
            // SQLite's documented evaluation order.
            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
331
    /// Polls every batch that still has unanswered cache rows and stores the
    /// results of any that have completed.
    ///
    /// Unlike `import_batches`, this discovers batch ids from the cache
    /// itself (rows with a batch_id but no response) and UPDATEs rows in
    /// place rather than upserting.
    async fn download_finished_batches(&self) -> Result<()> {
        // Collect candidate batch ids up front so the connection lock is not
        // held across the awaits below.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            // Batches that are still in progress (or failed) are left alone;
            // they will be polled again on the next sync.
            if batch_status.status == "completed" {
                // A completed batch without an output file is logged and
                // skipped rather than treated as fatal.
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                // (response_json, request_hash) pairs for the UPDATE below.
                let mut updates: Vec<(String, String)> = Vec::new();
                // Counts 200-status results only; error payloads are still
                // stored in the cache but not counted here.
                let mut success_count = 0;

                for result in results {
                    // Custom ids were written as "req_hash_<hash>" at upload
                    // time; recover the bare hash.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            // Store a structured error payload so the failure
                            // is recorded in the cache.
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                // Write all responses atomically inside a savepoint.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
435
436 async fn upload_pending_requests(&self) -> Result<Vec<String>> {
437 const BATCH_CHUNK_SIZE: i32 = 16_000;
438 let mut all_batch_ids = Vec::new();
439 let mut total_uploaded = 0;
440
441 loop {
442 let rows: Vec<(String, String)> = {
443 let connection = self.connection.lock().unwrap();
444 let q = sql!(
445 SELECT request_hash, request FROM openai_cache
446 WHERE batch_id IS NULL AND response IS NULL
447 LIMIT ?
448 );
449 connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
450 };
451
452 if rows.is_empty() {
453 break;
454 }
455
456 let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
457
458 let mut jsonl_content = String::new();
459 for (hash, request_str) in &rows {
460 let serializable_request: SerializableRequest =
461 serde_json::from_str(request_str).unwrap();
462
463 let messages: Vec<RequestMessage> = serializable_request
464 .messages
465 .into_iter()
466 .map(|msg| match msg.role.as_str() {
467 "user" => RequestMessage::User {
468 content: MessageContent::Plain(msg.content),
469 },
470 "assistant" => RequestMessage::Assistant {
471 content: Some(MessageContent::Plain(msg.content)),
472 tool_calls: Vec::new(),
473 },
474 "system" => RequestMessage::System {
475 content: MessageContent::Plain(msg.content),
476 },
477 _ => RequestMessage::User {
478 content: MessageContent::Plain(msg.content),
479 },
480 })
481 .collect();
482
483 let request = OpenAiRequest {
484 model: serializable_request.model,
485 messages,
486 stream: false,
487 max_completion_tokens: Some(serializable_request.max_tokens),
488 stop: Vec::new(),
489 temperature: None,
490 tool_choice: None,
491 parallel_tool_calls: None,
492 tools: Vec::new(),
493 prompt_cache_key: None,
494 reasoning_effort: None,
495 };
496
497 let custom_id = format!("req_hash_{}", hash);
498 let batch_item = batches::BatchRequestItem::new(custom_id, request);
499 let line = batch_item
500 .to_jsonl_line()
501 .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
502 jsonl_content.push_str(&line);
503 jsonl_content.push('\n');
504 }
505
506 let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
507 let file_obj = batches::upload_batch_file(
508 self.http_client.as_ref(),
509 OPEN_AI_API_URL,
510 &self.api_key,
511 &filename,
512 jsonl_content.into_bytes(),
513 )
514 .await
515 .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;
516
517 let batch = batches::create_batch(
518 self.http_client.as_ref(),
519 OPEN_AI_API_URL,
520 &self.api_key,
521 batches::CreateBatchRequest::new(file_obj.id),
522 )
523 .await
524 .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;
525
526 {
527 let connection = self.connection.lock().unwrap();
528 connection.with_savepoint("batch_upload", || {
529 let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
530 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
531 for hash in &request_hashes {
532 exec((batch.id.as_str(), hash.as_str()))?;
533 }
534 Ok(())
535 })?;
536 }
537
538 let batch_len = rows.len();
539 total_uploaded += batch_len;
540 log::info!(
541 "Uploaded batch {} with {} requests ({} total)",
542 batch.id,
543 batch_len,
544 total_uploaded
545 );
546
547 all_batch_ids.push(batch.id);
548 }
549
550 if !all_batch_ids.is_empty() {
551 log::info!(
552 "Finished uploading {} batches with {} total requests",
553 all_batch_ids.len(),
554 total_uploaded
555 );
556 }
557
558 Ok(all_batch_ids)
559 }
560
561 fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
562 let mut hasher = std::hash::DefaultHasher::new();
563 "openai".hash(&mut hasher);
564 model.hash(&mut hasher);
565 max_tokens.hash(&mut hasher);
566 for msg in messages {
567 message_content_to_string(msg).hash(&mut hasher);
568 }
569 let request_hash = hasher.finish();
570 format!("{request_hash:016x}")
571 }
572}
573
574fn message_role_to_string(msg: &RequestMessage) -> String {
575 match msg {
576 RequestMessage::User { .. } => "user".to_string(),
577 RequestMessage::Assistant { .. } => "assistant".to_string(),
578 RequestMessage::System { .. } => "system".to_string(),
579 RequestMessage::Tool { .. } => "tool".to_string(),
580 }
581}
582
583fn message_content_to_string(msg: &RequestMessage) -> String {
584 match msg {
585 RequestMessage::User { content } => content_to_string(content),
586 RequestMessage::Assistant { content, .. } => {
587 content.as_ref().map(content_to_string).unwrap_or_default()
588 }
589 RequestMessage::System { content } => content_to_string(content),
590 RequestMessage::Tool { content, .. } => content_to_string(content),
591 }
592}
593
594fn content_to_string(content: &MessageContent) -> String {
595 match content {
596 MessageContent::Plain(text) => text.clone(),
597 MessageContent::Multipart(parts) => parts
598 .iter()
599 .filter_map(|part| match part {
600 open_ai::MessagePart::Text { text } => Some(text.clone()),
601 _ => None,
602 })
603 .collect::<Vec<String>>()
604 .join("\n"),
605 }
606}
607
/// Dispatching wrapper over the two client implementations, plus a `Dummy`
/// placeholder whose methods panic if it is actually used.
pub enum OpenAiClient {
    // Sends each request immediately over HTTP.
    Plain(PlainOpenAiClient),
    // Queues requests into a SQLite cache for the Batch API.
    Batch(BatchingOpenAiClient),
    #[allow(dead_code)]
    Dummy,
}
614
615impl OpenAiClient {
616 pub fn plain() -> Result<Self> {
617 Ok(Self::Plain(PlainOpenAiClient::new()?))
618 }
619
620 pub fn batch(cache_path: &Path) -> Result<Self> {
621 Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
622 }
623
624 #[allow(dead_code)]
625 pub fn dummy() -> Self {
626 Self::Dummy
627 }
628
629 pub async fn generate(
630 &self,
631 model: &str,
632 max_tokens: u64,
633 messages: Vec<RequestMessage>,
634 ) -> Result<Option<OpenAiResponse>> {
635 match self {
636 OpenAiClient::Plain(plain_client) => plain_client
637 .generate(model, max_tokens, messages)
638 .await
639 .map(Some),
640 OpenAiClient::Batch(batching_client) => {
641 batching_client.generate(model, max_tokens, messages).await
642 }
643 OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
644 }
645 }
646
647 pub async fn sync_batches(&self) -> Result<()> {
648 match self {
649 OpenAiClient::Plain(_) => Ok(()),
650 OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
651 OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
652 }
653 }
654
655 pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
656 match self {
657 OpenAiClient::Plain(_) => {
658 anyhow::bail!("Import batches is only supported with batching client")
659 }
660 OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
661 OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
662 }
663 }
664}