openai_client.rs

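//! OpenAI client wrappers: `PlainOpenAiClient` sends one-off, non-streaming
//! completion requests, while `BatchingOpenAiClient` queues requests in a
//! local SQLite cache and round-trips them through the OpenAI Batch API.
//! `OpenAiClient` (at the bottom of the file) is the facade over both.
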
use anyhow::{Context as _, Result};
use http_client::HttpClient;
use indoc::indoc;
use open_ai::{
    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
    Response as OpenAiResponse, batches, non_streaming_completion,
};
use reqwest_client::ReqwestClient;
use sqlez::bindable::{Bind, StaticColumnCount};
use sqlez_macros::sql;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::{Arc, Mutex};

/// Thin wrapper around the non-streaming completions endpoint.
pub struct PlainOpenAiClient {
    pub http_client: Arc<dyn HttpClient>,
    pub api_key: String,
}

impl PlainOpenAiClient {
    pub fn new() -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
        Ok(Self {
            http_client,
            api_key,
        })
    }

    /// Sends a single completion request and waits for the response.
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
    ) -> Result<OpenAiResponse> {
        let request = OpenAiRequest {
            model: model.to_string(),
            messages,
            stream: false,
            max_completion_tokens: Some(max_tokens),
            stop: Vec::new(),
            temperature: None,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: Vec::new(),
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        let response = non_streaming_completion(
            self.http_client.as_ref(),
            OPEN_AI_API_URL,
            &self.api_key,
            request,
        )
        .await
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;

        Ok(response)
    }
}
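
// Usage sketch (illustrative only; the model name is arbitrary and an async
// runtime is assumed):
//
//     let client = PlainOpenAiClient::new()?;
//     let response = client
//         .generate(
//             "gpt-4o-mini",
//             256,
//             vec![RequestMessage::User {
//                 content: MessageContent::Plain("Hello!".into()),
//             }],
//         )
//         .await?;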

/// Caching client backed by a local SQLite database and the OpenAI Batch API.
pub struct BatchingOpenAiClient {
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}

/// One row of the `openai_cache` table.
struct CacheRow {
    request_hash: String,
    request: Option<String>,
    response: Option<String>,
    batch_id: Option<String>,
}

impl StaticColumnCount for CacheRow {
    fn column_count() -> usize {
        4
    }
}

impl Bind for CacheRow {
    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
        let next_index = statement.bind(&self.request_hash, start_index)?;
        let next_index = statement.bind(&self.request, next_index)?;
        let next_index = statement.bind(&self.response, next_index)?;
        let next_index = statement.bind(&self.batch_id, next_index)?;
        Ok(next_index)
    }
}

/// JSON shape stored in the `request` column of the cache.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
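
// Illustrative cached-request payload (field names follow the serde derives
// above; the values are made up):
//
//     {"model":"gpt-4o-mini","max_tokens":256,
//      "messages":[{"role":"user","content":"Hello!"}]}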

impl BatchingOpenAiClient {
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;

        let connection = sqlez::connection::Connection::open_file(
            cache_path
                .to_str()
                .context("cache path is not valid UTF-8")?,
        );
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS openai_cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
                "},
        )?;
        statement.exec()?;
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }
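
    // A cache row moves through three states:
    //   1. queued:    request set, batch_id and response NULL (mark_for_batch)
    //   2. uploaded:  batch_id set, response still NULL (upload_pending_requests)
    //   3. completed: response set to a serialized OpenAiResponse or an error
    //      blob (download_finished_batches / import_batches)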

    /// Returns the cached response for this request, if one has already been
    /// downloaded.
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
    ) -> Result<Option<OpenAiResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }

    /// Queues a request for the next batch upload; a no-op if the request is
    /// already queued or answered.
    pub fn mark_for_batch(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
    ) -> Result<()> {
        let request_hash = Self::request_hash(model, max_tokens, messages);

        let serializable_messages: Vec<SerializableMessage> = messages
            .iter()
            .map(|msg| SerializableMessage {
                role: message_role_to_string(msg),
                content: message_content_to_string(msg),
            })
            .collect();

        let serializable_request = SerializableRequest {
            model: model.to_string(),
            max_tokens,
            messages: serializable_messages,
        };

        let request = Some(serde_json::to_string(&serializable_request)?);
        let cache_row = CacheRow {
            request_hash,
            request,
            response: None,
            batch_id: None,
        };
        let connection = self.connection.lock().unwrap();
        connection.exec_bound::<CacheRow>(sql!(
            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id)
            VALUES (?, ?, ?, ?)
        ))?(cache_row)
    }

    /// Returns the cached response if available; otherwise queues the request
    /// and returns `None`.
    async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
    ) -> Result<Option<OpenAiResponse>> {
        let response = self.lookup(model, max_tokens, &messages)?;
        if let Some(response) = response {
            return Ok(Some(response));
        }

        self.mark_for_batch(model, max_tokens, &messages)?;

        Ok(None)
    }

    /// Uploads any queued requests and downloads results for finished batches.
    async fn sync_batches(&self) -> Result<()> {
        let _batch_ids = self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }
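
    // Typical lifecycle sketch (illustrative; `cache.db` is a made-up path,
    // and OpenAI batches can take up to 24 hours to complete, so the second
    // pass usually happens in a later run of the same program):
    //
    //     let client = BatchingOpenAiClient::new(Path::new("cache.db"))?;
    //     // First run: generate() returns Ok(None) and queues the request.
    //     client.sync_batches().await?; // upload queued, download finished
    //     // Later run: generate()/lookup() return the cached response.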

    /// Imports the results of already-completed batches by ID, writing each
    /// response (or error blob) into the cache.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Preserve the original request column (if any) while
                // replacing the response and batch ID.
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }

    async fn download_finished_batches(&self) -> Result<()> {
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status == "completed" {
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;

                for result in results {
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }

    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
        // Upload pending requests in chunks of at most this many rows per
        // batch file.
        const BATCH_CHUNK_SIZE: i32 = 16_000;
        let mut all_batch_ids = Vec::new();
        let mut total_uploaded = 0;

        loop {
            let rows: Vec<(String, String)> = {
                let connection = self.connection.lock().unwrap();
                let q = sql!(
                    SELECT request_hash, request FROM openai_cache
                    WHERE batch_id IS NULL AND response IS NULL
                    LIMIT ?
                );
                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
            };

            if rows.is_empty() {
                break;
            }

            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();

            let mut jsonl_content = String::new();
            for (hash, request_str) in &rows {
                let serializable_request: SerializableRequest = serde_json::from_str(request_str)
                    .context("malformed request JSON in openai_cache")?;

                let messages: Vec<RequestMessage> = serializable_request
                    .messages
                    .into_iter()
                    .map(|msg| match msg.role.as_str() {
                        "user" => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                        "assistant" => RequestMessage::Assistant {
                            content: Some(MessageContent::Plain(msg.content)),
                            tool_calls: Vec::new(),
                        },
                        "system" => RequestMessage::System {
                            content: MessageContent::Plain(msg.content),
                        },
                        _ => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                    })
                    .collect();

                let request = OpenAiRequest {
                    model: serializable_request.model,
                    messages,
                    stream: false,
                    max_completion_tokens: Some(serializable_request.max_tokens),
                    stop: Vec::new(),
                    temperature: None,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: Vec::new(),
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                let custom_id = format!("req_hash_{}", hash);
                let batch_item = batches::BatchRequestItem::new(custom_id, request);
                let line = batch_item
                    .to_jsonl_line()
                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
                jsonl_content.push_str(&line);
                jsonl_content.push('\n');
            }

            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
            let file_obj = batches::upload_batch_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &filename,
                jsonl_content.into_bytes(),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;

            let batch = batches::create_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batches::CreateBatchRequest::new(file_obj.id),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;

            {
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_upload", || {
                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for hash in &request_hashes {
                        exec((batch.id.as_str(), hash.as_str()))?;
                    }
                    Ok(())
                })?;
            }

            let batch_len = rows.len();
            total_uploaded += batch_len;
            log::info!(
                "Uploaded batch {} with {} requests ({} total)",
                batch.id,
                batch_len,
                total_uploaded
            );

            all_batch_ids.push(batch.id);
        }

        if !all_batch_ids.is_empty() {
            log::info!(
                "Finished uploading {} batches with {} total requests",
                all_batch_ids.len(),
                total_uploaded
            );
        }

        Ok(all_batch_ids)
    }

    /// Cache key for a request. Note: `DefaultHasher` output is not
    /// guaranteed to be stable across Rust releases, and message roles are
    /// not hashed, so this key is best treated as local to one build.
    fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        "openai".hash(&mut hasher);
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(msg).hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        format!("{request_hash:016x}")
    }
}
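
// Manual import sketch: when batch IDs are known out of band (for example,
// from the OpenAI dashboard), their results can be written straight into the
// cache. The ID below is made up:
//
//     client.import_batches(&["batch_abc123".to_string()]).await?;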

fn message_role_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { .. } => "user".to_string(),
        RequestMessage::Assistant { .. } => "assistant".to_string(),
        RequestMessage::System { .. } => "system".to_string(),
        RequestMessage::Tool { .. } => "tool".to_string(),
    }
}

fn message_content_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { content } => content_to_string(content),
        RequestMessage::Assistant { content, .. } => {
            content.as_ref().map(content_to_string).unwrap_or_default()
        }
        RequestMessage::System { content } => content_to_string(content),
        RequestMessage::Tool { content, .. } => content_to_string(content),
    }
}

fn content_to_string(content: &MessageContent) -> String {
    match content {
        MessageContent::Plain(text) => text.clone(),
        MessageContent::Multipart(parts) => parts
            .iter()
            .filter_map(|part| match part {
                open_ai::MessagePart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n"),
    }
}

/// Facade over the plain and batching clients, plus a `Dummy` variant that
/// panics if used.
pub enum OpenAiClient {
    Plain(PlainOpenAiClient),
    Batch(BatchingOpenAiClient),
    #[allow(dead_code)]
    Dummy,
}

impl OpenAiClient {
    pub fn plain() -> Result<Self> {
        Ok(Self::Plain(PlainOpenAiClient::new()?))
    }

    pub fn batch(cache_path: &Path) -> Result<Self> {
        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
    }

    #[allow(dead_code)]
    pub fn dummy() -> Self {
        Self::Dummy
    }

    /// Plain clients always return `Some(response)`; batching clients return
    /// `None` until the response has been downloaded into the cache.
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
    ) -> Result<Option<OpenAiResponse>> {
        match self {
            OpenAiClient::Plain(plain_client) => plain_client
                .generate(model, max_tokens, messages)
                .await
                .map(Some),
            OpenAiClient::Batch(batching_client) => {
                batching_client.generate(model, max_tokens, messages).await
            }
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn sync_batches(&self) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => Ok(()),
            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => {
                anyhow::bail!("Import batches is only supported with batching client")
            }
            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }
}
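
// End-to-end sketch of the facade (hypothetical cache path; assumes an async
// runtime and `messages: Vec<RequestMessage>` built as above):
//
//     let client = OpenAiClient::batch(Path::new("cache.db"))?;
//     if let Some(response) = client.generate("gpt-4o-mini", 256, messages).await? {
//         // Use the completed response.
//     } else {
//         // Queued: upload now, poll again in a later run.
//         client.sync_batches().await?;
//     }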