openai_client.rs

use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
use open_ai::{
    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
    Response as OpenAiResponse, batches, non_streaming_completion,
};
use reqwest_client::ReqwestClient;
use sqlez::bindable::{Bind, StaticColumnCount};
use sqlez_macros::sql;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::{Arc, Mutex};

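/// Direct OpenAI client: sends each request immediately via the
/// non-streaming completions endpoint. Reads the API key from the
/// OPENAI_API_KEY environment variable.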
pub struct PlainOpenAiClient {
    pub http_client: Arc<dyn HttpClient>,
    pub api_key: String,
}

impl PlainOpenAiClient {
    pub fn new() -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
        Ok(Self {
            http_client,
            api_key,
        })
    }

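    /// Sends a single completion request and waits for the response.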
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
    ) -> Result<OpenAiResponse> {
        let request = OpenAiRequest {
            model: model.to_string(),
            messages,
            stream: false,
            max_completion_tokens: Some(max_tokens),
            stop: Vec::new(),
            temperature: None,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: Vec::new(),
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        let response = non_streaming_completion(
            self.http_client.as_ref(),
            OPEN_AI_API_URL,
            &self.api_key,
            request,
        )
        .await
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;

        Ok(response)
    }
}

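/// OpenAI client backed by the Batch API. Requests and responses are cached
/// in a local SQLite database so that work survives across runs.
///
/// Lifecycle: `generate` returns a cached response if one exists; otherwise it
/// records the request (`mark_for_batch`) and returns `None`. A later call to
/// `sync_batches` uploads all pending requests as JSONL batch files and
/// downloads results for any batches that have completed.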
pub struct BatchingOpenAiClient {
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}

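/// One row of the `openai_cache` table. `response` and `batch_id` are NULL
/// until the request has been batched and a result downloaded.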
struct CacheRow {
    request_hash: String,
    request: Option<String>,
    response: Option<String>,
    batch_id: Option<String>,
}

impl StaticColumnCount for CacheRow {
    fn column_count() -> usize {
        4
    }
}

impl Bind for CacheRow {
    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
        let next_index = statement.bind(&self.request_hash, start_index)?;
        let next_index = statement.bind(&self.request, next_index)?;
        let next_index = statement.bind(&self.response, next_index)?;
        let next_index = statement.bind(&self.batch_id, next_index)?;
        Ok(next_index)
    }
}

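/// Requests are stored in this simplified form (role + flattened text) rather
/// than as full `OpenAiRequest` values, and are rehydrated at upload time.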
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}

impl BatchingOpenAiClient {
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;

        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS openai_cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
                "},
        )?;
        statement.exec()?;
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }

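    /// Returns the cached response for this request, if one has been
    /// downloaded. Rows whose stored response fails to deserialize are
    /// treated as misses.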
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }

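    /// Records a request for inclusion in the next uploaded batch. Uses
    /// INSERT OR IGNORE, so re-marking an already-known request is a no-op.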
    pub fn mark_for_batch(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> Result<()> {
        let request_hash = Self::request_hash(model, max_tokens, messages, seed);

        let serializable_messages: Vec<SerializableMessage> = messages
            .iter()
            .map(|msg| SerializableMessage {
                role: message_role_to_string(msg),
                content: message_content_to_string(msg),
            })
            .collect();

        let serializable_request = SerializableRequest {
            model: model.to_string(),
            max_tokens,
            messages: serializable_messages,
        };

        let request = Some(serde_json::to_string(&serializable_request)?);
        let cache_row = CacheRow {
            request_hash,
            request,
            response: None,
            batch_id: None,
        };
        let connection = self.connection.lock().unwrap();
        connection.exec_bound::<CacheRow>(sql!(
            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
            cache_row,
        )
    }

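    /// Cache-or-enqueue: returns `Some(response)` on a cache hit; otherwise
    /// marks the request for the next batch upload and returns `None`.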
    async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        let response = self.lookup(model, max_tokens, &messages, seed)?;
        if let Some(response) = response {
            return Ok(Some(response));
        }

        self.mark_for_batch(model, max_tokens, &messages, seed)?;

        Ok(None)
    }

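    /// Uploads any pending requests, then downloads results for every
    /// tracked batch that has finished.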
    async fn sync_batches(&self) -> Result<()> {
        let _batch_ids = self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }

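    /// Imports results from externally created batches by id, upserting both
    /// successful responses and per-request errors into the cache.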
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Preserve any existing request body via the subselect while
                // replacing the response and batch id.
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }

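    /// Polls every batch we have uploaded but not yet resolved, and stores
    /// the results (or per-request errors) for batches that have completed.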
    async fn download_finished_batches(&self) -> Result<()> {
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status == "completed" {
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;

                for result in results {
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }

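    /// Uploads pending requests in chunks of up to BATCH_CHUNK_SIZE rows.
    /// Each chunk becomes one JSONL file and one batch; uploaded rows are
    /// tagged with their batch id so they are not re-uploaded.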
    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
        const BATCH_CHUNK_SIZE: i32 = 16_000;
        let mut all_batch_ids = Vec::new();
        let mut total_uploaded = 0;

        loop {
            let rows: Vec<(String, String)> = {
                let connection = self.connection.lock().unwrap();
                let q = sql!(
                    SELECT request_hash, request FROM openai_cache
                    WHERE batch_id IS NULL AND response IS NULL
                    LIMIT ?
                );
                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
            };

            if rows.is_empty() {
                break;
            }

            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();

            let mut jsonl_content = String::new();
            for (hash, request_str) in &rows {
                // Propagate (rather than panic on) a corrupt cached request.
                let serializable_request: SerializableRequest = serde_json::from_str(request_str)
                    .map_err(|e| anyhow::anyhow!("Corrupt cached request {}: {:?}", hash, e))?;

                let messages: Vec<RequestMessage> = serializable_request
                    .messages
                    .into_iter()
                    .map(|msg| match msg.role.as_str() {
                        "user" => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                        "assistant" => RequestMessage::Assistant {
                            content: Some(MessageContent::Plain(msg.content)),
                            tool_calls: Vec::new(),
                        },
                        "system" => RequestMessage::System {
                            content: MessageContent::Plain(msg.content),
                        },
                        _ => RequestMessage::User {
                            content: MessageContent::Plain(msg.content),
                        },
                    })
                    .collect();

                let request = OpenAiRequest {
                    model: serializable_request.model,
                    messages,
                    stream: false,
                    max_completion_tokens: Some(serializable_request.max_tokens),
                    stop: Vec::new(),
                    temperature: None,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: Vec::new(),
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                // The request hash doubles as the batch custom_id, which is
                // how results are matched back to cache rows on download.
                let custom_id = format!("req_hash_{}", hash);
                let batch_item = batches::BatchRequestItem::new(custom_id, request);
                let line = batch_item
                    .to_jsonl_line()
                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
                jsonl_content.push_str(&line);
                jsonl_content.push('\n');
            }

            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
            let file_obj = batches::upload_batch_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &filename,
                jsonl_content.into_bytes(),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;

            let batch = batches::create_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batches::CreateBatchRequest::new(file_obj.id),
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;

            {
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_upload", || {
                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for hash in &request_hashes {
                        exec((batch.id.as_str(), hash.as_str()))?;
                    }
                    Ok(())
                })?;
            }

            let batch_len = rows.len();
            total_uploaded += batch_len;
            log::info!(
                "Uploaded batch {} with {} requests ({} total)",
                batch.id,
                batch_len,
                total_uploaded
            );

            all_batch_ids.push(batch.id);
        }

        if !all_batch_ids.is_empty() {
            log::info!(
                "Finished uploading {} batches with {} total requests",
                all_batch_ids.len(),
                total_uploaded
            );
        }

        Ok(all_batch_ids)
    }

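    /// Cache key over model, token limit, message contents, and optional
    /// seed. Note: message roles are not part of the hash, and std's
    /// `DefaultHasher` does not guarantee stable output across Rust
    /// releases, so cache keys may change when the toolchain changes.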
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[RequestMessage],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        "openai".hash(&mut hasher);
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(msg).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        format!("{request_hash:016x}")
    }
}

fn message_role_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { .. } => "user".to_string(),
        RequestMessage::Assistant { .. } => "assistant".to_string(),
        RequestMessage::System { .. } => "system".to_string(),
        RequestMessage::Tool { .. } => "tool".to_string(),
    }
}

fn message_content_to_string(msg: &RequestMessage) -> String {
    match msg {
        RequestMessage::User { content } => content_to_string(content),
        RequestMessage::Assistant { content, .. } => {
            content.as_ref().map(content_to_string).unwrap_or_default()
        }
        RequestMessage::System { content } => content_to_string(content),
        RequestMessage::Tool { content, .. } => content_to_string(content),
    }
}

/// Flattens message content to plain text; non-text multipart parts
/// (e.g. images) are dropped.
fn content_to_string(content: &MessageContent) -> String {
    match content {
        MessageContent::Plain(text) => text.clone(),
        MessageContent::Multipart(parts) => parts
            .iter()
            .filter_map(|part| match part {
                open_ai::MessagePart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n"),
    }
}

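/// Front-end enum dispatching to either the plain or batching client.
///
/// A minimal polling loop with the batching client might look like this
/// (sketch: the model name, messages, and sleep cadence are placeholders,
/// not part of this module):
///
/// ```ignore
/// let client = OpenAiClient::batch(Path::new("openai_cache.db"))?;
/// let response = loop {
///     // Some(response) on a cache hit; None after enqueueing the request.
///     if let Some(r) = client.generate("gpt-4o-mini", 1024, messages.clone(), None).await? {
///         break r;
///     }
///     // Upload pending requests and pull results for finished batches.
///     client.sync_batches().await?;
///     tokio::time::sleep(std::time::Duration::from_secs(60)).await;
/// };
/// ```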
pub enum OpenAiClient {
    Plain(PlainOpenAiClient),
    Batch(BatchingOpenAiClient),
    #[allow(dead_code)]
    Dummy,
}

impl OpenAiClient {
    pub fn plain() -> Result<Self> {
        Ok(Self::Plain(PlainOpenAiClient::new()?))
    }

    pub fn batch(cache_path: &Path) -> Result<Self> {
        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
    }

    #[allow(dead_code)]
    pub fn dummy() -> Self {
        Self::Dummy
    }

    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<RequestMessage>,
        seed: Option<usize>,
    ) -> Result<Option<OpenAiResponse>> {
        match self {
            OpenAiClient::Plain(plain_client) => plain_client
                .generate(model, max_tokens, messages)
                .await
                .map(Some),
            OpenAiClient::Batch(batching_client) => {
                batching_client
                    .generate(model, max_tokens, messages, seed)
                    .await
            }
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn sync_batches(&self) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => Ok(()),
            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }

    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        match self {
            OpenAiClient::Plain(_) => {
                anyhow::bail!("Import batches is only supported with batching client")
            }
            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
        }
    }
}