openai_client.rs

  1use anyhow::Result;
  2use http_client::HttpClient;
  3use indoc::indoc;
  4use open_ai::{
  5    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
  6    Response as OpenAiResponse, batches, non_streaming_completion,
  7};
  8use reqwest_client::ReqwestClient;
  9use sqlez::bindable::Bind;
 10use sqlez::bindable::StaticColumnCount;
 11use sqlez_macros::sql;
 12use std::hash::Hash;
 13use std::hash::Hasher;
 14use std::path::Path;
 15use std::sync::{Arc, Mutex};
 16
 17pub struct PlainOpenAiClient {
 18    pub http_client: Arc<dyn HttpClient>,
 19    pub api_key: String,
 20}
 21
 22impl PlainOpenAiClient {
 23    pub fn new() -> Result<Self> {
 24        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 25        let api_key = std::env::var("OPENAI_API_KEY")
 26            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
 27        Ok(Self {
 28            http_client,
 29            api_key,
 30        })
 31    }
 32
 33    pub async fn generate(
 34        &self,
 35        model: &str,
 36        max_tokens: u64,
 37        messages: Vec<RequestMessage>,
 38    ) -> Result<OpenAiResponse> {
 39        let request = OpenAiRequest {
 40            model: model.to_string(),
 41            messages,
 42            stream: false,
 43            stream_options: None,
 44            max_completion_tokens: Some(max_tokens),
 45            stop: Vec::new(),
 46            temperature: None,
 47            tool_choice: None,
 48            parallel_tool_calls: None,
 49            tools: Vec::new(),
 50            prompt_cache_key: None,
 51            reasoning_effort: None,
 52        };
 53
 54        let response = non_streaming_completion(
 55            self.http_client.as_ref(),
 56            OPEN_AI_API_URL,
 57            &self.api_key,
 58            request,
 59        )
 60        .await
 61        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 62
 63        Ok(response)
 64    }
 65}
 66
 67pub struct BatchingOpenAiClient {
 68    connection: Mutex<sqlez::connection::Connection>,
 69    http_client: Arc<dyn HttpClient>,
 70    api_key: String,
 71}
 72
 73struct CacheRow {
 74    request_hash: String,
 75    request: Option<String>,
 76    response: Option<String>,
 77    batch_id: Option<String>,
 78}
 79
 80impl StaticColumnCount for CacheRow {
 81    fn column_count() -> usize {
 82        4
 83    }
 84}
 85
 86impl Bind for CacheRow {
 87    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
 88        let next_index = statement.bind(&self.request_hash, start_index)?;
 89        let next_index = statement.bind(&self.request, next_index)?;
 90        let next_index = statement.bind(&self.response, next_index)?;
 91        let next_index = statement.bind(&self.batch_id, next_index)?;
 92        Ok(next_index)
 93    }
 94}
 95
 96#[derive(serde::Serialize, serde::Deserialize)]
 97struct SerializableRequest {
 98    model: String,
 99    max_tokens: u64,
100    messages: Vec<SerializableMessage>,
101}
102
103#[derive(serde::Serialize, serde::Deserialize)]
104struct SerializableMessage {
105    role: String,
106    content: String,
107}
108
109impl BatchingOpenAiClient {
110    fn new(cache_path: &Path) -> Result<Self> {
111        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
112        let api_key = std::env::var("OPENAI_API_KEY")
113            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
114
115        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
116        let mut statement = sqlez::statement::Statement::prepare(
117            &connection,
118            indoc! {"
119                CREATE TABLE IF NOT EXISTS openai_cache (
120                    request_hash TEXT PRIMARY KEY,
121                    request TEXT,
122                    response TEXT,
123                    batch_id TEXT
124                );
125                "},
126        )?;
127        statement.exec()?;
128        drop(statement);
129
130        Ok(Self {
131            connection: Mutex::new(connection),
132            http_client,
133            api_key,
134        })
135    }
136
137    pub fn lookup(
138        &self,
139        model: &str,
140        max_tokens: u64,
141        messages: &[RequestMessage],
142        seed: Option<usize>,
143    ) -> Result<Option<OpenAiResponse>> {
144        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
145        let connection = self.connection.lock().unwrap();
146        let response: Vec<String> = connection.select_bound(
147            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
148        )?(request_hash_str.as_str())?;
149        Ok(response
150            .into_iter()
151            .next()
152            .and_then(|text| serde_json::from_str(&text).ok()))
153    }
154
155    pub fn mark_for_batch(
156        &self,
157        model: &str,
158        max_tokens: u64,
159        messages: &[RequestMessage],
160        seed: Option<usize>,
161    ) -> Result<()> {
162        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
163
164        let serializable_messages: Vec<SerializableMessage> = messages
165            .iter()
166            .map(|msg| SerializableMessage {
167                role: message_role_to_string(msg),
168                content: message_content_to_string(msg),
169            })
170            .collect();
171
172        let serializable_request = SerializableRequest {
173            model: model.to_string(),
174            max_tokens,
175            messages: serializable_messages,
176        };
177
178        let request = Some(serde_json::to_string(&serializable_request)?);
179        let cache_row = CacheRow {
180            request_hash,
181            request,
182            response: None,
183            batch_id: None,
184        };
185        let connection = self.connection.lock().unwrap();
186        connection.exec_bound::<CacheRow>(sql!(
187            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
188            cache_row,
189        )
190    }
191
192    async fn generate(
193        &self,
194        model: &str,
195        max_tokens: u64,
196        messages: Vec<RequestMessage>,
197        seed: Option<usize>,
198        cache_only: bool,
199    ) -> Result<Option<OpenAiResponse>> {
200        let response = self.lookup(model, max_tokens, &messages, seed)?;
201        if let Some(response) = response {
202            return Ok(Some(response));
203        }
204
205        if !cache_only {
206            self.mark_for_batch(model, max_tokens, &messages, seed)?;
207        }
208
209        Ok(None)
210    }
211
212    async fn sync_batches(&self) -> Result<()> {
213        let _batch_ids = self.upload_pending_requests().await?;
214        self.download_finished_batches().await
215    }
216
217    pub fn pending_batch_count(&self) -> Result<usize> {
218        let connection = self.connection.lock().unwrap();
219        let counts: Vec<i32> = connection.select(
220            sql!(SELECT COUNT(*) FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL),
221        )?()?;
222        Ok(counts.into_iter().next().unwrap_or(0) as usize)
223    }
224
225    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
226        for batch_id in batch_ids {
227            log::info!("Importing OpenAI batch {}", batch_id);
228
229            let batch_status = batches::retrieve_batch(
230                self.http_client.as_ref(),
231                OPEN_AI_API_URL,
232                &self.api_key,
233                batch_id,
234            )
235            .await
236            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;
237
238            log::info!("Batch {} status: {}", batch_id, batch_status.status);
239
240            if batch_status.status != "completed" {
241                log::warn!(
242                    "Batch {} is not completed (status: {}), skipping",
243                    batch_id,
244                    batch_status.status
245                );
246                continue;
247            }
248
249            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
250                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
251            })?;
252
253            let results_content = batches::download_file(
254                self.http_client.as_ref(),
255                OPEN_AI_API_URL,
256                &self.api_key,
257                &output_file_id,
258            )
259            .await
260            .map_err(|e| {
261                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
262            })?;
263
264            let results = batches::parse_batch_output(&results_content)
265                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
266
267            let mut updates: Vec<(String, String, String)> = Vec::new();
268            let mut success_count = 0;
269            let mut error_count = 0;
270
271            for result in results {
272                let request_hash = result
273                    .custom_id
274                    .strip_prefix("req_hash_")
275                    .unwrap_or(&result.custom_id)
276                    .to_string();
277
278                if let Some(response_body) = result.response {
279                    if response_body.status_code == 200 {
280                        let response_json = serde_json::to_string(&response_body.body)?;
281                        updates.push((request_hash, response_json, batch_id.clone()));
282                        success_count += 1;
283                    } else {
284                        log::error!(
285                            "Batch request {} failed with status {}",
286                            request_hash,
287                            response_body.status_code
288                        );
289                        let error_json = serde_json::json!({
290                            "error": {
291                                "type": "http_error",
292                                "status_code": response_body.status_code
293                            }
294                        })
295                        .to_string();
296                        updates.push((request_hash, error_json, batch_id.clone()));
297                        error_count += 1;
298                    }
299                } else if let Some(error) = result.error {
300                    log::error!(
301                        "Batch request {} failed: {}: {}",
302                        request_hash,
303                        error.code,
304                        error.message
305                    );
306                    let error_json = serde_json::json!({
307                        "error": {
308                            "type": error.code,
309                            "message": error.message
310                        }
311                    })
312                    .to_string();
313                    updates.push((request_hash, error_json, batch_id.clone()));
314                    error_count += 1;
315                }
316            }
317
318            let connection = self.connection.lock().unwrap();
319            connection.with_savepoint("batch_import", || {
320                let q = sql!(
321                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
322                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
323                );
324                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
325                for (request_hash, response_json, batch_id) in &updates {
326                    exec((
327                        request_hash.as_str(),
328                        request_hash.as_str(),
329                        response_json.as_str(),
330                        batch_id.as_str(),
331                    ))?;
332                }
333                Ok(())
334            })?;
335
336            log::info!(
337                "Imported batch {}: {} successful, {} errors",
338                batch_id,
339                success_count,
340                error_count
341            );
342        }
343
344        Ok(())
345    }
346
347    async fn download_finished_batches(&self) -> Result<()> {
348        let batch_ids: Vec<String> = {
349            let connection = self.connection.lock().unwrap();
350            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
351            connection.select(q)?()?
352        };
353
354        for batch_id in &batch_ids {
355            let batch_status = batches::retrieve_batch(
356                self.http_client.as_ref(),
357                OPEN_AI_API_URL,
358                &self.api_key,
359                batch_id,
360            )
361            .await
362            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
363
364            log::info!("Batch {} status: {}", batch_id, batch_status.status);
365
366            if batch_status.status == "completed" {
367                let output_file_id = match batch_status.output_file_id {
368                    Some(id) => id,
369                    None => {
370                        log::warn!("Batch {} completed but has no output file", batch_id);
371                        continue;
372                    }
373                };
374
375                let results_content = batches::download_file(
376                    self.http_client.as_ref(),
377                    OPEN_AI_API_URL,
378                    &self.api_key,
379                    &output_file_id,
380                )
381                .await
382                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
383
384                let results = batches::parse_batch_output(&results_content)
385                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
386
387                let mut updates: Vec<(String, String)> = Vec::new();
388                let mut success_count = 0;
389
390                for result in results {
391                    let request_hash = result
392                        .custom_id
393                        .strip_prefix("req_hash_")
394                        .unwrap_or(&result.custom_id)
395                        .to_string();
396
397                    if let Some(response_body) = result.response {
398                        if response_body.status_code == 200 {
399                            let response_json = serde_json::to_string(&response_body.body)?;
400                            updates.push((response_json, request_hash));
401                            success_count += 1;
402                        } else {
403                            log::error!(
404                                "Batch request {} failed with status {}",
405                                request_hash,
406                                response_body.status_code
407                            );
408                            let error_json = serde_json::json!({
409                                "error": {
410                                    "type": "http_error",
411                                    "status_code": response_body.status_code
412                                }
413                            })
414                            .to_string();
415                            updates.push((error_json, request_hash));
416                        }
417                    } else if let Some(error) = result.error {
418                        log::error!(
419                            "Batch request {} failed: {}: {}",
420                            request_hash,
421                            error.code,
422                            error.message
423                        );
424                        let error_json = serde_json::json!({
425                            "error": {
426                                "type": error.code,
427                                "message": error.message
428                            }
429                        })
430                        .to_string();
431                        updates.push((error_json, request_hash));
432                    }
433                }
434
435                let connection = self.connection.lock().unwrap();
436                connection.with_savepoint("batch_download", || {
437                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
438                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
439                    for (response_json, request_hash) in &updates {
440                        exec((response_json.as_str(), request_hash.as_str()))?;
441                    }
442                    Ok(())
443                })?;
444                log::info!("Downloaded {} successful requests", success_count);
445            }
446        }
447
448        Ok(())
449    }
450
451    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
452        const BATCH_CHUNK_SIZE: i32 = 16_000;
453        let mut all_batch_ids = Vec::new();
454        let mut total_uploaded = 0;
455
456        loop {
457            let rows: Vec<(String, String)> = {
458                let connection = self.connection.lock().unwrap();
459                let q = sql!(
460                    SELECT request_hash, request FROM openai_cache
461                    WHERE batch_id IS NULL AND response IS NULL
462                    LIMIT ?
463                );
464                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
465            };
466
467            if rows.is_empty() {
468                break;
469            }
470
471            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
472
473            let mut jsonl_content = String::new();
474            for (hash, request_str) in &rows {
475                let serializable_request: SerializableRequest =
476                    serde_json::from_str(request_str).unwrap();
477
478                let messages: Vec<RequestMessage> = serializable_request
479                    .messages
480                    .into_iter()
481                    .map(|msg| match msg.role.as_str() {
482                        "user" => RequestMessage::User {
483                            content: MessageContent::Plain(msg.content),
484                        },
485                        "assistant" => RequestMessage::Assistant {
486                            content: Some(MessageContent::Plain(msg.content)),
487                            tool_calls: Vec::new(),
488                        },
489                        "system" => RequestMessage::System {
490                            content: MessageContent::Plain(msg.content),
491                        },
492                        _ => RequestMessage::User {
493                            content: MessageContent::Plain(msg.content),
494                        },
495                    })
496                    .collect();
497
498                let request = OpenAiRequest {
499                    model: serializable_request.model,
500                    messages,
501                    stream: false,
502                    stream_options: None,
503                    max_completion_tokens: Some(serializable_request.max_tokens),
504                    stop: Vec::new(),
505                    temperature: None,
506                    tool_choice: None,
507                    parallel_tool_calls: None,
508                    tools: Vec::new(),
509                    prompt_cache_key: None,
510                    reasoning_effort: None,
511                };
512
513                let custom_id = format!("req_hash_{}", hash);
514                let batch_item = batches::BatchRequestItem::new(custom_id, request);
515                let line = batch_item
516                    .to_jsonl_line()
517                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
518                jsonl_content.push_str(&line);
519                jsonl_content.push('\n');
520            }
521
522            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
523            let file_obj = batches::upload_batch_file(
524                self.http_client.as_ref(),
525                OPEN_AI_API_URL,
526                &self.api_key,
527                &filename,
528                jsonl_content.into_bytes(),
529            )
530            .await
531            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;
532
533            let batch = batches::create_batch(
534                self.http_client.as_ref(),
535                OPEN_AI_API_URL,
536                &self.api_key,
537                batches::CreateBatchRequest::new(file_obj.id),
538            )
539            .await
540            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;
541
542            {
543                let connection = self.connection.lock().unwrap();
544                connection.with_savepoint("batch_upload", || {
545                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
546                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
547                    for hash in &request_hashes {
548                        exec((batch.id.as_str(), hash.as_str()))?;
549                    }
550                    Ok(())
551                })?;
552            }
553
554            let batch_len = rows.len();
555            total_uploaded += batch_len;
556            log::info!(
557                "Uploaded batch {} with {} requests ({} total)",
558                batch.id,
559                batch_len,
560                total_uploaded
561            );
562
563            all_batch_ids.push(batch.id);
564        }
565
566        if !all_batch_ids.is_empty() {
567            log::info!(
568                "Finished uploading {} batches with {} total requests",
569                all_batch_ids.len(),
570                total_uploaded
571            );
572        }
573
574        Ok(all_batch_ids)
575    }
576
577    fn request_hash(
578        model: &str,
579        max_tokens: u64,
580        messages: &[RequestMessage],
581        seed: Option<usize>,
582    ) -> String {
583        let mut hasher = std::hash::DefaultHasher::new();
584        "openai".hash(&mut hasher);
585        model.hash(&mut hasher);
586        max_tokens.hash(&mut hasher);
587        for msg in messages {
588            message_content_to_string(msg).hash(&mut hasher);
589        }
590        if let Some(seed) = seed {
591            seed.hash(&mut hasher);
592        }
593        let request_hash = hasher.finish();
594        format!("{request_hash:016x}")
595    }
596}
597
598fn message_role_to_string(msg: &RequestMessage) -> String {
599    match msg {
600        RequestMessage::User { .. } => "user".to_string(),
601        RequestMessage::Assistant { .. } => "assistant".to_string(),
602        RequestMessage::System { .. } => "system".to_string(),
603        RequestMessage::Tool { .. } => "tool".to_string(),
604    }
605}
606
607fn message_content_to_string(msg: &RequestMessage) -> String {
608    match msg {
609        RequestMessage::User { content } => content_to_string(content),
610        RequestMessage::Assistant { content, .. } => {
611            content.as_ref().map(content_to_string).unwrap_or_default()
612        }
613        RequestMessage::System { content } => content_to_string(content),
614        RequestMessage::Tool { content, .. } => content_to_string(content),
615    }
616}
617
618fn content_to_string(content: &MessageContent) -> String {
619    match content {
620        MessageContent::Plain(text) => text.clone(),
621        MessageContent::Multipart(parts) => parts
622            .iter()
623            .filter_map(|part| match part {
624                open_ai::MessagePart::Text { text } => Some(text.clone()),
625                _ => None,
626            })
627            .collect::<Vec<String>>()
628            .join("\n"),
629    }
630}
631
632pub enum OpenAiClient {
633    Plain(PlainOpenAiClient),
634    Batch(BatchingOpenAiClient),
635    #[allow(dead_code)]
636    Dummy,
637}
638
639impl OpenAiClient {
640    pub fn plain() -> Result<Self> {
641        Ok(Self::Plain(PlainOpenAiClient::new()?))
642    }
643
644    pub fn batch(cache_path: &Path) -> Result<Self> {
645        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
646    }
647
648    #[allow(dead_code)]
649    pub fn dummy() -> Self {
650        Self::Dummy
651    }
652
653    pub async fn generate(
654        &self,
655        model: &str,
656        max_tokens: u64,
657        messages: Vec<RequestMessage>,
658        seed: Option<usize>,
659        cache_only: bool,
660    ) -> Result<Option<OpenAiResponse>> {
661        match self {
662            OpenAiClient::Plain(plain_client) => plain_client
663                .generate(model, max_tokens, messages)
664                .await
665                .map(Some),
666            OpenAiClient::Batch(batching_client) => {
667                batching_client
668                    .generate(model, max_tokens, messages, seed, cache_only)
669                    .await
670            }
671            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
672        }
673    }
674
675    pub async fn sync_batches(&self) -> Result<()> {
676        match self {
677            OpenAiClient::Plain(_) => Ok(()),
678            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
679            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
680        }
681    }
682
683    pub fn pending_batch_count(&self) -> Result<usize> {
684        match self {
685            OpenAiClient::Plain(_) => Ok(0),
686            OpenAiClient::Batch(batching_client) => batching_client.pending_batch_count(),
687            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
688        }
689    }
690
691    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
692        match self {
693            OpenAiClient::Plain(_) => {
694                anyhow::bail!("Import batches is only supported with batching client")
695            }
696            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
697            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
698        }
699    }
700}