// openai_client.rs

  1use anyhow::Result;
  2use http_client::HttpClient;
  3use indoc::indoc;
  4use open_ai::{
  5    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
  6    Response as OpenAiResponse, batches, non_streaming_completion,
  7};
  8use reqwest_client::ReqwestClient;
  9use sqlez::bindable::Bind;
 10use sqlez::bindable::StaticColumnCount;
 11use sqlez_macros::sql;
 12use std::hash::Hash;
 13use std::hash::Hasher;
 14use std::path::Path;
 15use std::sync::{Arc, Mutex};
 16
/// Direct (non-batching) OpenAI client: every `generate` call issues a
/// single non-streaming completion request against the OpenAI API.
pub struct PlainOpenAiClient {
    // HTTP transport used for all API calls.
    pub http_client: Arc<dyn HttpClient>,
    // API key read from the `OPENAI_API_KEY` environment variable.
    pub api_key: String,
}
 21
 22impl PlainOpenAiClient {
 23    pub fn new() -> Result<Self> {
 24        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 25        let api_key = std::env::var("OPENAI_API_KEY")
 26            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
 27        Ok(Self {
 28            http_client,
 29            api_key,
 30        })
 31    }
 32
 33    pub async fn generate(
 34        &self,
 35        model: &str,
 36        max_tokens: u64,
 37        messages: Vec<RequestMessage>,
 38    ) -> Result<OpenAiResponse> {
 39        let request = OpenAiRequest {
 40            model: model.to_string(),
 41            messages,
 42            stream: false,
 43            max_completion_tokens: Some(max_tokens),
 44            stop: Vec::new(),
 45            temperature: None,
 46            tool_choice: None,
 47            parallel_tool_calls: None,
 48            tools: Vec::new(),
 49            prompt_cache_key: None,
 50            reasoning_effort: None,
 51        };
 52
 53        let response = non_streaming_completion(
 54            self.http_client.as_ref(),
 55            OPEN_AI_API_URL,
 56            &self.api_key,
 57            request,
 58        )
 59        .await
 60        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 61
 62        Ok(response)
 63    }
 64}
 65
/// OpenAI client that routes requests through the Batch API, persisting
/// requests and responses in a local SQLite cache.
pub struct BatchingOpenAiClient {
    // SQLite cache (table `openai_cache`); the Mutex serializes access from
    // concurrent callers.
    connection: Mutex<sqlez::connection::Connection>,
    // HTTP transport used for batch upload/status/download calls.
    http_client: Arc<dyn HttpClient>,
    // API key read from the `OPENAI_API_KEY` environment variable.
    api_key: String,
}
 71
/// One row of the `openai_cache` table.
struct CacheRow {
    // Hex digest produced by `BatchingOpenAiClient::request_hash` (primary key).
    request_hash: String,
    // JSON-serialized `SerializableRequest`, when known.
    request: Option<String>,
    // JSON response body — or a JSON error payload — once results arrive.
    response: Option<String>,
    // OpenAI batch id once the request has been uploaded.
    batch_id: Option<String>,
}
 78
impl StaticColumnCount for CacheRow {
    // Number of SQL parameters a `CacheRow` binds; must stay in sync with the
    // four `statement.bind` calls in the `Bind` impl and the `openai_cache`
    // column list.
    fn column_count() -> usize {
        4
    }
}
 84
 85impl Bind for CacheRow {
 86    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
 87        let next_index = statement.bind(&self.request_hash, start_index)?;
 88        let next_index = statement.bind(&self.request, next_index)?;
 89        let next_index = statement.bind(&self.response, next_index)?;
 90        let next_index = statement.bind(&self.batch_id, next_index)?;
 91        Ok(next_index)
 92    }
 93}
 94
/// JSON-storable form of a chat request, persisted in the cache's `request`
/// column and replayed when building batch upload files.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
101
/// Flattened chat message: a role string ("user"/"assistant"/"system"/"tool")
/// plus plain-text content (multipart content is joined to text on save).
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
107
impl BatchingOpenAiClient {
    /// Opens (or creates) the SQLite cache at `cache_path`, ensures the
    /// `openai_cache` table exists, and reads the API key from the
    /// `OPENAI_API_KEY` environment variable.
    ///
    /// # Errors
    /// Fails when `OPENAI_API_KEY` is unset or the table-creation statement
    /// cannot be prepared or executed.
    // NOTE(review): `cache_path.to_str().unwrap()` panics on a non-UTF-8 path;
    // consider surfacing that as an error.
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;

        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS openai_cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
                "},
        )?;
        statement.exec()?;
        // The prepared statement borrows `connection`; drop it before moving
        // the connection into the Mutex below.
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }

    /// Looks up a cached response for (`model`, `max_tokens`, `messages`,
    /// `seed`).
    ///
    /// Returns `Ok(None)` when no row has a non-NULL `response`, or when the
    /// stored text fails to deserialize as an `OpenAiResponse`. Error payloads
    /// are also stored in the `response` column — presumably those fail to
    /// deserialize and come back as `None` here; confirm against
    /// `OpenAiResponse`'s serde shape.
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }

    /// Records the request in the cache with no response and no batch id so
    /// that a later `upload_pending_requests` call includes it in a batch.
    /// `INSERT OR IGNORE` leaves an existing row (and its response) intact.
    pub fn mark_for_batch(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<()> {
        let request_hash = Self::request_hash(model, max_tokens, messages, seed);

        // Flatten messages to (role, text) pairs for JSON storage.
        let serializable_messages: Vec<SerializableMessage> = messages
            .iter()
            .map(|msg| SerializableMessage {
                role: message_role_to_string(msg),
                content: message_content_to_string(msg),
            })
            .collect();

        let serializable_request = SerializableRequest {
            model: model.to_string(),
            max_tokens,
            messages: serializable_messages,
        };

        let request = Some(serde_json::to_string(&serializable_request)?);
        let cache_row = CacheRow {
            request_hash,
            request,
            response: None,
            batch_id: None,
        };
        let connection = self.connection.lock().unwrap();
        connection.exec_bound::<CacheRow>(sql!(
            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
            cache_row,
        )
    }

    /// Cache-first generation: returns a cached response when available;
    /// otherwise marks the request for a future batch upload (unless
    /// `cache_only`) and returns `Ok(None)`. Never calls the API directly.
    async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
        cache_only: bool,
    ) -> Result<Option<OpenAiResponse>> {
        let response = self.lookup(model, max_tokens, &messages, seed)?;
        if let Some(response) = response {
            return Ok(Some(response));
        }

        if !cache_only {
            self.mark_for_batch(model, max_tokens, &messages, seed)?;
        }

        Ok(None)
    }

    /// Uploads any pending (un-batched) requests, then downloads results for
    /// every batch that has since completed.
    async fn sync_batches(&self) -> Result<()> {
        let _batch_ids = self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }

    /// Imports results from explicitly named batches, upserting each result —
    /// success body or error payload — into the cache keyed by the
    /// `req_hash_`-prefixed custom id. Batches that are not in the
    /// "completed" state are skipped with a warning.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            // (request_hash, response_json, batch_id) triples to upsert.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // custom_id was written as "req_hash_<hash>" on upload.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        // Store an error payload so the row is not re-uploaded.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            // Upsert all rows in one savepoint; the subquery preserves any
            // `request` text already stored for this hash.
            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }

    /// Polls every batch id that still has unanswered rows; for completed
    /// batches, downloads the output file and stores each result (success
    /// body or error payload) in the cache's `response` column.
    async fn download_finished_batches(&self) -> Result<()> {
        // Scope the lock so it is released before the network calls below.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status == "completed" {
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                // (response_json, request_hash) pairs — note the order matches
                // the UPDATE's parameter order below.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;

                for result in results {
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }

    /// Uploads cached requests that have neither a response nor a batch id,
    /// in chunks of up to `BATCH_CHUNK_SIZE` rows, creating one OpenAI batch
    /// per chunk and stamping the rows with the resulting batch id. Returns
    /// the ids of all batches created.
    // NOTE(review): `serde_json::from_str(request_str).unwrap()` below panics
    // on a corrupt `request` column — consider propagating an error instead.
    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
        const BATCH_CHUNK_SIZE: i32 = 16_000;
        let mut all_batch_ids = Vec::new();
        let mut total_uploaded = 0;

        loop {
            // Fetch the next chunk of pending rows; the lock is scoped so it
            // is released before the network calls below.
            let rows: Vec<(String, String)> = {
                let connection = self.connection.lock().unwrap();
                let q = sql!(
                    SELECT request_hash, request FROM openai_cache
                    WHERE batch_id IS NULL AND response IS NULL
                    LIMIT ?
                );
                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
            };

            if rows.is_empty() {
                break;
            }

            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();

            // Build the JSONL upload body, one request per line.
            let mut jsonl_content = String::new();
            for (hash, request_str) in &rows {
                let serializable_request: SerializableRequest =
                    serde_json::from_str(request_str).unwrap();

                // Rehydrate stored (role, text) pairs into request messages.
                // NOTE(review): there is no "tool" arm here, so messages saved
                // with role "tool" (see `message_role_to_string`) are replayed
                // as user messages — confirm this is intentional.
                let messages: Vec<RequestMessage> = serializable_request
                    .messages
                    .into_iter()
                    .map(|msg| match msg.role.as_str() {
                        "user" => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                        "assistant" => RequestMessage::Assistant {
                            content: Some(MessageContent::Plain(msg.content)),
                            tool_calls: Vec::new(),
                        },
                        "system" => RequestMessage::System {
                            content: MessageContent::Plain(msg.content),
                        },
                        _ => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                    })
                    .collect();

                let request = OpenAiRequest {
                    model: serializable_request.model,
                    messages,
                    stream: false,
                    max_completion_tokens: Some(serializable_request.max_tokens),
                    stop: Vec::new(),
                    temperature: None,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: Vec::new(),
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                // custom_id carries the cache key so results can be matched up
                // when the batch output is downloaded.
                let custom_id = format!("req_hash_{}", hash);
                let batch_item = batches::BatchRequestItem::new(custom_id, request);
                let line = batch_item
                    .to_jsonl_line()
                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
                jsonl_content.push_str(&line);
                jsonl_content.push('\n');
            }

            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
            let file_obj = batches::upload_batch_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &filename,
                jsonl_content.into_bytes(),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;

            let batch = batches::create_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batches::CreateBatchRequest::new(file_obj.id),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;

            // Stamp the uploaded rows with the new batch id in one savepoint
            // so they are not picked up by the next loop iteration.
            {
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_upload", || {
                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for hash in &request_hashes {
                        exec((batch.id.as_str(), hash.as_str()))?;
                    }
                    Ok(())
                })?;
            }

            let batch_len = rows.len();
            total_uploaded += batch_len;
            log::info!(
                "Uploaded batch {} with {} requests ({} total)",
                batch.id,
                batch_len,
                total_uploaded
            );

            all_batch_ids.push(batch.id);
        }

        if !all_batch_ids.is_empty() {
            log::info!(
                "Finished uploading {} batches with {} total requests",
                all_batch_ids.len(),
                total_uploaded
            );
        }

        Ok(all_batch_ids)
    }

    /// Cache key for a request: a 16-hex-digit digest of the model,
    /// max_tokens, message contents, and optional seed.
    // NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable
    // across Rust releases, so a toolchain upgrade can silently invalidate the
    // on-disk cache. Message roles are also excluded from the hash, so two
    // requests differing only in roles collide — confirm both are acceptable.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        "openai".hash(&mut hasher);
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(msg).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        format!("{request_hash:016x}")
    }
}
587
588fn message_role_to_string(msg: &RequestMessage) -> String {
589    match msg {
590        RequestMessage::User { .. } => "user".to_string(),
591        RequestMessage::Assistant { .. } => "assistant".to_string(),
592        RequestMessage::System { .. } => "system".to_string(),
593        RequestMessage::Tool { .. } => "tool".to_string(),
594    }
595}
596
597fn message_content_to_string(msg: &RequestMessage) -> String {
598    match msg {
599        RequestMessage::User { content } => content_to_string(content),
600        RequestMessage::Assistant { content, .. } => {
601            content.as_ref().map(content_to_string).unwrap_or_default()
602        }
603        RequestMessage::System { content } => content_to_string(content),
604        RequestMessage::Tool { content, .. } => content_to_string(content),
605    }
606}
607
608fn content_to_string(content: &MessageContent) -> String {
609    match content {
610        MessageContent::Plain(text) => text.clone(),
611        MessageContent::Multipart(parts) => parts
612            .iter()
613            .filter_map(|part| match part {
614                open_ai::MessagePart::Text { text } => Some(text.clone()),
615                _ => None,
616            })
617            .collect::<Vec<String>>()
618            .join("\n"),
619    }
620}
621
/// Facade over the two client implementations plus a placeholder variant.
pub enum OpenAiClient {
    // Immediate per-request API calls.
    Plain(PlainOpenAiClient),
    // Batch-API-backed client with a SQLite cache.
    Batch(BatchingOpenAiClient),
    // Placeholder that panics if any operation is invoked on it.
    #[allow(dead_code)]
    Dummy,
}
628
629impl OpenAiClient {
630    pub fn plain() -> Result<Self> {
631        Ok(Self::Plain(PlainOpenAiClient::new()?))
632    }
633
634    pub fn batch(cache_path: &Path) -> Result<Self> {
635        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
636    }
637
638    #[allow(dead_code)]
639    pub fn dummy() -> Self {
640        Self::Dummy
641    }
642
643    pub async fn generate(
644        &self,
645        model: &str,
646        max_tokens: u64,
647        messages: Vec<RequestMessage>,
648        seed: Option<usize>,
649        cache_only: bool,
650    ) -> Result<Option<OpenAiResponse>> {
651        match self {
652            OpenAiClient::Plain(plain_client) => plain_client
653                .generate(model, max_tokens, messages)
654                .await
655                .map(Some),
656            OpenAiClient::Batch(batching_client) => {
657                batching_client
658                    .generate(model, max_tokens, messages, seed, cache_only)
659                    .await
660            }
661            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
662        }
663    }
664
665    pub async fn sync_batches(&self) -> Result<()> {
666        match self {
667            OpenAiClient::Plain(_) => Ok(()),
668            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
669            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
670        }
671    }
672
673    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
674        match self {
675            OpenAiClient::Plain(_) => {
676                anyhow::bail!("Import batches is only supported with batching client")
677            }
678            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
679            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
680        }
681    }
682}