openai_client.rs

  1use anyhow::Result;
  2use http_client::HttpClient;
  3use indoc::indoc;
  4use open_ai::{
  5    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
  6    Response as OpenAiResponse, batches, non_streaming_completion,
  7};
  8use reqwest_client::ReqwestClient;
  9use sqlez::bindable::Bind;
 10use sqlez::bindable::StaticColumnCount;
 11use sqlez_macros::sql;
 12use std::hash::Hash;
 13use std::hash::Hasher;
 14use std::path::Path;
 15use std::sync::{Arc, Mutex};
 16
 17pub struct PlainOpenAiClient {
 18    pub http_client: Arc<dyn HttpClient>,
 19    pub api_key: String,
 20}
 21
 22impl PlainOpenAiClient {
 23    pub fn new() -> Result<Self> {
 24        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 25        let api_key = std::env::var("OPENAI_API_KEY")
 26            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
 27        Ok(Self {
 28            http_client,
 29            api_key,
 30        })
 31    }
 32
 33    pub async fn generate(
 34        &self,
 35        model: &str,
 36        max_tokens: u64,
 37        messages: Vec<RequestMessage>,
 38    ) -> Result<OpenAiResponse> {
 39        let request = OpenAiRequest {
 40            model: model.to_string(),
 41            messages,
 42            stream: false,
 43            stream_options: None,
 44            max_completion_tokens: Some(max_tokens),
 45            stop: Vec::new(),
 46            temperature: None,
 47            tool_choice: None,
 48            parallel_tool_calls: None,
 49            tools: Vec::new(),
 50            prompt_cache_key: None,
 51            reasoning_effort: None,
 52        };
 53
 54        let response = non_streaming_completion(
 55            self.http_client.as_ref(),
 56            OPEN_AI_API_URL,
 57            &self.api_key,
 58            request,
 59        )
 60        .await
 61        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 62
 63        Ok(response)
 64    }
 65}
 66
 67pub struct BatchingOpenAiClient {
 68    connection: Mutex<sqlez::connection::Connection>,
 69    http_client: Arc<dyn HttpClient>,
 70    api_key: String,
 71}
 72
 73struct CacheRow {
 74    request_hash: String,
 75    request: Option<String>,
 76    response: Option<String>,
 77    batch_id: Option<String>,
 78}
 79
 80impl StaticColumnCount for CacheRow {
 81    fn column_count() -> usize {
 82        4
 83    }
 84}
 85
 86impl Bind for CacheRow {
 87    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
 88        let next_index = statement.bind(&self.request_hash, start_index)?;
 89        let next_index = statement.bind(&self.request, next_index)?;
 90        let next_index = statement.bind(&self.response, next_index)?;
 91        let next_index = statement.bind(&self.batch_id, next_index)?;
 92        Ok(next_index)
 93    }
 94}
 95
 96#[derive(serde::Serialize, serde::Deserialize)]
 97struct SerializableRequest {
 98    model: String,
 99    max_tokens: u64,
100    messages: Vec<SerializableMessage>,
101}
102
103#[derive(serde::Serialize, serde::Deserialize)]
104struct SerializableMessage {
105    role: String,
106    content: String,
107}
108
109impl BatchingOpenAiClient {
110    fn new(cache_path: &Path) -> Result<Self> {
111        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
112        let api_key = std::env::var("OPENAI_API_KEY")
113            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
114
115        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
116        let mut statement = sqlez::statement::Statement::prepare(
117            &connection,
118            indoc! {"
119                CREATE TABLE IF NOT EXISTS openai_cache (
120                    request_hash TEXT PRIMARY KEY,
121                    request TEXT,
122                    response TEXT,
123                    batch_id TEXT
124                );
125                "},
126        )?;
127        statement.exec()?;
128        drop(statement);
129
130        Ok(Self {
131            connection: Mutex::new(connection),
132            http_client,
133            api_key,
134        })
135    }
136
137    pub fn lookup(
138        &self,
139        model: &str,
140        max_tokens: u64,
141        messages: &[RequestMessage],
142        seed: Option<usize>,
143    ) -> Result<Option<OpenAiResponse>> {
144        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
145        let connection = self.connection.lock().unwrap();
146        let response: Vec<String> = connection.select_bound(
147            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
148        )?(request_hash_str.as_str())?;
149        Ok(response
150            .into_iter()
151            .next()
152            .and_then(|text| serde_json::from_str(&text).ok()))
153    }
154
155    pub fn mark_for_batch(
156        &self,
157        model: &str,
158        max_tokens: u64,
159        messages: &[RequestMessage],
160        seed: Option<usize>,
161    ) -> Result<()> {
162        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
163
164        let serializable_messages: Vec<SerializableMessage> = messages
165            .iter()
166            .map(|msg| SerializableMessage {
167                role: message_role_to_string(msg),
168                content: message_content_to_string(msg),
169            })
170            .collect();
171
172        let serializable_request = SerializableRequest {
173            model: model.to_string(),
174            max_tokens,
175            messages: serializable_messages,
176        };
177
178        let request = Some(serde_json::to_string(&serializable_request)?);
179        let cache_row = CacheRow {
180            request_hash,
181            request,
182            response: None,
183            batch_id: None,
184        };
185        let connection = self.connection.lock().unwrap();
186        connection.exec_bound::<CacheRow>(sql!(
187            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
188            cache_row,
189        )
190    }
191
192    async fn generate(
193        &self,
194        model: &str,
195        max_tokens: u64,
196        messages: Vec<RequestMessage>,
197        seed: Option<usize>,
198        cache_only: bool,
199    ) -> Result<Option<OpenAiResponse>> {
200        let response = self.lookup(model, max_tokens, &messages, seed)?;
201        if let Some(response) = response {
202            return Ok(Some(response));
203        }
204
205        if !cache_only {
206            self.mark_for_batch(model, max_tokens, &messages, seed)?;
207        }
208
209        Ok(None)
210    }
211
212    async fn sync_batches(&self) -> Result<()> {
213        let _batch_ids = self.upload_pending_requests().await?;
214        self.download_finished_batches().await
215    }
216
217    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
218        for batch_id in batch_ids {
219            log::info!("Importing OpenAI batch {}", batch_id);
220
221            let batch_status = batches::retrieve_batch(
222                self.http_client.as_ref(),
223                OPEN_AI_API_URL,
224                &self.api_key,
225                batch_id,
226            )
227            .await
228            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;
229
230            log::info!("Batch {} status: {}", batch_id, batch_status.status);
231
232            if batch_status.status != "completed" {
233                log::warn!(
234                    "Batch {} is not completed (status: {}), skipping",
235                    batch_id,
236                    batch_status.status
237                );
238                continue;
239            }
240
241            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
242                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
243            })?;
244
245            let results_content = batches::download_file(
246                self.http_client.as_ref(),
247                OPEN_AI_API_URL,
248                &self.api_key,
249                &output_file_id,
250            )
251            .await
252            .map_err(|e| {
253                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
254            })?;
255
256            let results = batches::parse_batch_output(&results_content)
257                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
258
259            let mut updates: Vec<(String, String, String)> = Vec::new();
260            let mut success_count = 0;
261            let mut error_count = 0;
262
263            for result in results {
264                let request_hash = result
265                    .custom_id
266                    .strip_prefix("req_hash_")
267                    .unwrap_or(&result.custom_id)
268                    .to_string();
269
270                if let Some(response_body) = result.response {
271                    if response_body.status_code == 200 {
272                        let response_json = serde_json::to_string(&response_body.body)?;
273                        updates.push((request_hash, response_json, batch_id.clone()));
274                        success_count += 1;
275                    } else {
276                        log::error!(
277                            "Batch request {} failed with status {}",
278                            request_hash,
279                            response_body.status_code
280                        );
281                        let error_json = serde_json::json!({
282                            "error": {
283                                "type": "http_error",
284                                "status_code": response_body.status_code
285                            }
286                        })
287                        .to_string();
288                        updates.push((request_hash, error_json, batch_id.clone()));
289                        error_count += 1;
290                    }
291                } else if let Some(error) = result.error {
292                    log::error!(
293                        "Batch request {} failed: {}: {}",
294                        request_hash,
295                        error.code,
296                        error.message
297                    );
298                    let error_json = serde_json::json!({
299                        "error": {
300                            "type": error.code,
301                            "message": error.message
302                        }
303                    })
304                    .to_string();
305                    updates.push((request_hash, error_json, batch_id.clone()));
306                    error_count += 1;
307                }
308            }
309
310            let connection = self.connection.lock().unwrap();
311            connection.with_savepoint("batch_import", || {
312                let q = sql!(
313                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
314                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
315                );
316                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
317                for (request_hash, response_json, batch_id) in &updates {
318                    exec((
319                        request_hash.as_str(),
320                        request_hash.as_str(),
321                        response_json.as_str(),
322                        batch_id.as_str(),
323                    ))?;
324                }
325                Ok(())
326            })?;
327
328            log::info!(
329                "Imported batch {}: {} successful, {} errors",
330                batch_id,
331                success_count,
332                error_count
333            );
334        }
335
336        Ok(())
337    }
338
339    async fn download_finished_batches(&self) -> Result<()> {
340        let batch_ids: Vec<String> = {
341            let connection = self.connection.lock().unwrap();
342            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
343            connection.select(q)?()?
344        };
345
346        for batch_id in &batch_ids {
347            let batch_status = batches::retrieve_batch(
348                self.http_client.as_ref(),
349                OPEN_AI_API_URL,
350                &self.api_key,
351                batch_id,
352            )
353            .await
354            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
355
356            log::info!("Batch {} status: {}", batch_id, batch_status.status);
357
358            if batch_status.status == "completed" {
359                let output_file_id = match batch_status.output_file_id {
360                    Some(id) => id,
361                    None => {
362                        log::warn!("Batch {} completed but has no output file", batch_id);
363                        continue;
364                    }
365                };
366
367                let results_content = batches::download_file(
368                    self.http_client.as_ref(),
369                    OPEN_AI_API_URL,
370                    &self.api_key,
371                    &output_file_id,
372                )
373                .await
374                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
375
376                let results = batches::parse_batch_output(&results_content)
377                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;
378
379                let mut updates: Vec<(String, String)> = Vec::new();
380                let mut success_count = 0;
381
382                for result in results {
383                    let request_hash = result
384                        .custom_id
385                        .strip_prefix("req_hash_")
386                        .unwrap_or(&result.custom_id)
387                        .to_string();
388
389                    if let Some(response_body) = result.response {
390                        if response_body.status_code == 200 {
391                            let response_json = serde_json::to_string(&response_body.body)?;
392                            updates.push((response_json, request_hash));
393                            success_count += 1;
394                        } else {
395                            log::error!(
396                                "Batch request {} failed with status {}",
397                                request_hash,
398                                response_body.status_code
399                            );
400                            let error_json = serde_json::json!({
401                                "error": {
402                                    "type": "http_error",
403                                    "status_code": response_body.status_code
404                                }
405                            })
406                            .to_string();
407                            updates.push((error_json, request_hash));
408                        }
409                    } else if let Some(error) = result.error {
410                        log::error!(
411                            "Batch request {} failed: {}: {}",
412                            request_hash,
413                            error.code,
414                            error.message
415                        );
416                        let error_json = serde_json::json!({
417                            "error": {
418                                "type": error.code,
419                                "message": error.message
420                            }
421                        })
422                        .to_string();
423                        updates.push((error_json, request_hash));
424                    }
425                }
426
427                let connection = self.connection.lock().unwrap();
428                connection.with_savepoint("batch_download", || {
429                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
430                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
431                    for (response_json, request_hash) in &updates {
432                        exec((response_json.as_str(), request_hash.as_str()))?;
433                    }
434                    Ok(())
435                })?;
436                log::info!("Downloaded {} successful requests", success_count);
437            }
438        }
439
440        Ok(())
441    }
442
443    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
444        const BATCH_CHUNK_SIZE: i32 = 16_000;
445        let mut all_batch_ids = Vec::new();
446        let mut total_uploaded = 0;
447
448        loop {
449            let rows: Vec<(String, String)> = {
450                let connection = self.connection.lock().unwrap();
451                let q = sql!(
452                    SELECT request_hash, request FROM openai_cache
453                    WHERE batch_id IS NULL AND response IS NULL
454                    LIMIT ?
455                );
456                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
457            };
458
459            if rows.is_empty() {
460                break;
461            }
462
463            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
464
465            let mut jsonl_content = String::new();
466            for (hash, request_str) in &rows {
467                let serializable_request: SerializableRequest =
468                    serde_json::from_str(request_str).unwrap();
469
470                let messages: Vec<RequestMessage> = serializable_request
471                    .messages
472                    .into_iter()
473                    .map(|msg| match msg.role.as_str() {
474                        "user" => RequestMessage::User {
475                            content: MessageContent::Plain(msg.content),
476                        },
477                        "assistant" => RequestMessage::Assistant {
478                            content: Some(MessageContent::Plain(msg.content)),
479                            tool_calls: Vec::new(),
480                        },
481                        "system" => RequestMessage::System {
482                            content: MessageContent::Plain(msg.content),
483                        },
484                        _ => RequestMessage::User {
485                            content: MessageContent::Plain(msg.content),
486                        },
487                    })
488                    .collect();
489
490                let request = OpenAiRequest {
491                    model: serializable_request.model,
492                    messages,
493                    stream: false,
494                    stream_options: None,
495                    max_completion_tokens: Some(serializable_request.max_tokens),
496                    stop: Vec::new(),
497                    temperature: None,
498                    tool_choice: None,
499                    parallel_tool_calls: None,
500                    tools: Vec::new(),
501                    prompt_cache_key: None,
502                    reasoning_effort: None,
503                };
504
505                let custom_id = format!("req_hash_{}", hash);
506                let batch_item = batches::BatchRequestItem::new(custom_id, request);
507                let line = batch_item
508                    .to_jsonl_line()
509                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
510                jsonl_content.push_str(&line);
511                jsonl_content.push('\n');
512            }
513
514            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
515            let file_obj = batches::upload_batch_file(
516                self.http_client.as_ref(),
517                OPEN_AI_API_URL,
518                &self.api_key,
519                &filename,
520                jsonl_content.into_bytes(),
521            )
522            .await
523            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;
524
525            let batch = batches::create_batch(
526                self.http_client.as_ref(),
527                OPEN_AI_API_URL,
528                &self.api_key,
529                batches::CreateBatchRequest::new(file_obj.id),
530            )
531            .await
532            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;
533
534            {
535                let connection = self.connection.lock().unwrap();
536                connection.with_savepoint("batch_upload", || {
537                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
538                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
539                    for hash in &request_hashes {
540                        exec((batch.id.as_str(), hash.as_str()))?;
541                    }
542                    Ok(())
543                })?;
544            }
545
546            let batch_len = rows.len();
547            total_uploaded += batch_len;
548            log::info!(
549                "Uploaded batch {} with {} requests ({} total)",
550                batch.id,
551                batch_len,
552                total_uploaded
553            );
554
555            all_batch_ids.push(batch.id);
556        }
557
558        if !all_batch_ids.is_empty() {
559            log::info!(
560                "Finished uploading {} batches with {} total requests",
561                all_batch_ids.len(),
562                total_uploaded
563            );
564        }
565
566        Ok(all_batch_ids)
567    }
568
569    fn request_hash(
570        model: &str,
571        max_tokens: u64,
572        messages: &[RequestMessage],
573        seed: Option<usize>,
574    ) -> String {
575        let mut hasher = std::hash::DefaultHasher::new();
576        "openai".hash(&mut hasher);
577        model.hash(&mut hasher);
578        max_tokens.hash(&mut hasher);
579        for msg in messages {
580            message_content_to_string(msg).hash(&mut hasher);
581        }
582        if let Some(seed) = seed {
583            seed.hash(&mut hasher);
584        }
585        let request_hash = hasher.finish();
586        format!("{request_hash:016x}")
587    }
588}
589
590fn message_role_to_string(msg: &RequestMessage) -> String {
591    match msg {
592        RequestMessage::User { .. } => "user".to_string(),
593        RequestMessage::Assistant { .. } => "assistant".to_string(),
594        RequestMessage::System { .. } => "system".to_string(),
595        RequestMessage::Tool { .. } => "tool".to_string(),
596    }
597}
598
599fn message_content_to_string(msg: &RequestMessage) -> String {
600    match msg {
601        RequestMessage::User { content } => content_to_string(content),
602        RequestMessage::Assistant { content, .. } => {
603            content.as_ref().map(content_to_string).unwrap_or_default()
604        }
605        RequestMessage::System { content } => content_to_string(content),
606        RequestMessage::Tool { content, .. } => content_to_string(content),
607    }
608}
609
610fn content_to_string(content: &MessageContent) -> String {
611    match content {
612        MessageContent::Plain(text) => text.clone(),
613        MessageContent::Multipart(parts) => parts
614            .iter()
615            .filter_map(|part| match part {
616                open_ai::MessagePart::Text { text } => Some(text.clone()),
617                _ => None,
618            })
619            .collect::<Vec<String>>()
620            .join("\n"),
621    }
622}
623
624pub enum OpenAiClient {
625    Plain(PlainOpenAiClient),
626    Batch(BatchingOpenAiClient),
627    #[allow(dead_code)]
628    Dummy,
629}
630
631impl OpenAiClient {
632    pub fn plain() -> Result<Self> {
633        Ok(Self::Plain(PlainOpenAiClient::new()?))
634    }
635
636    pub fn batch(cache_path: &Path) -> Result<Self> {
637        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
638    }
639
640    #[allow(dead_code)]
641    pub fn dummy() -> Self {
642        Self::Dummy
643    }
644
645    pub async fn generate(
646        &self,
647        model: &str,
648        max_tokens: u64,
649        messages: Vec<RequestMessage>,
650        seed: Option<usize>,
651        cache_only: bool,
652    ) -> Result<Option<OpenAiResponse>> {
653        match self {
654            OpenAiClient::Plain(plain_client) => plain_client
655                .generate(model, max_tokens, messages)
656                .await
657                .map(Some),
658            OpenAiClient::Batch(batching_client) => {
659                batching_client
660                    .generate(model, max_tokens, messages, seed, cache_only)
661                    .await
662            }
663            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
664        }
665    }
666
667    pub async fn sync_batches(&self) -> Result<()> {
668        match self {
669            OpenAiClient::Plain(_) => Ok(()),
670            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
671            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
672        }
673    }
674
675    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
676        match self {
677            OpenAiClient::Plain(_) => {
678                anyhow::bail!("Import batches is only supported with batching client")
679            }
680            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
681            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
682        }
683    }
684}