// openai_client.rs

  1use anyhow::Result;
  2use http_client::HttpClient;
  3use indoc::indoc;
  4use open_ai::{
  5    MessageContent, OPEN_AI_API_URL, Request as OpenAiRequest, RequestMessage,
  6    Response as OpenAiResponse, batches, non_streaming_completion,
  7};
  8use reqwest_client::ReqwestClient;
  9use sqlez::bindable::Bind;
 10use sqlez::bindable::StaticColumnCount;
 11use sqlez_macros::sql;
 12use std::hash::Hash;
 13use std::hash::Hasher;
 14use std::path::Path;
 15use std::sync::{Arc, Mutex};
 16
/// A thin OpenAI client that sends each request immediately, with no
/// caching or batching.
pub struct PlainOpenAiClient {
    /// HTTP transport used for all API calls.
    pub http_client: Arc<dyn HttpClient>,
    /// API key read from the `OPENAI_API_KEY` environment variable in `new`.
    pub api_key: String,
}
 21
 22impl PlainOpenAiClient {
 23    pub fn new() -> Result<Self> {
 24        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 25        let api_key = std::env::var("OPENAI_API_KEY")
 26            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
 27        Ok(Self {
 28            http_client,
 29            api_key,
 30        })
 31    }
 32
 33    pub async fn generate(
 34        &self,
 35        model: &str,
 36        max_tokens: u64,
 37        messages: Vec<RequestMessage>,
 38    ) -> Result<OpenAiResponse> {
 39        let request = OpenAiRequest {
 40            model: model.to_string(),
 41            messages,
 42            stream: false,
 43            stream_options: None,
 44            max_completion_tokens: Some(max_tokens),
 45            stop: Vec::new(),
 46            temperature: None,
 47            tool_choice: None,
 48            parallel_tool_calls: None,
 49            tools: Vec::new(),
 50            prompt_cache_key: None,
 51            reasoning_effort: None,
 52        };
 53
 54        let response = non_streaming_completion(
 55            self.http_client.as_ref(),
 56            OPEN_AI_API_URL,
 57            &self.api_key,
 58            request,
 59        )
 60        .await
 61        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 62
 63        Ok(response)
 64    }
 65}
 66
/// An OpenAI client that serves requests from a local SQLite cache and
/// fulfills cache misses via the OpenAI Batch API.
pub struct BatchingOpenAiClient {
    /// SQLite cache holding the `openai_cache` table; wrapped in a `Mutex`
    /// for exclusive access across methods.
    connection: Mutex<sqlez::connection::Connection>,
    /// HTTP transport used for all batch API calls.
    http_client: Arc<dyn HttpClient>,
    /// API key read from the `OPENAI_API_KEY` environment variable in `new`.
    api_key: String,
}
 72
/// One row of the `openai_cache` table.
struct CacheRow {
    /// Hex-encoded request hash; the table's primary key.
    request_hash: String,
    /// JSON-serialized `SerializableRequest`, when the request body is stored.
    request: Option<String>,
    /// JSON-serialized OpenAI response (or synthetic error object), once known.
    response: Option<String>,
    /// Id of the OpenAI batch this request was uploaded in, once uploaded.
    batch_id: Option<String>,
}
 79
impl StaticColumnCount for CacheRow {
    fn column_count() -> usize {
        // Must stay in sync with the number of values bound in `Bind::bind`.
        4
    }
}
 85
 86impl Bind for CacheRow {
 87    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
 88        let next_index = statement.bind(&self.request_hash, start_index)?;
 89        let next_index = statement.bind(&self.request, next_index)?;
 90        let next_index = statement.bind(&self.response, next_index)?;
 91        let next_index = statement.bind(&self.batch_id, next_index)?;
 92        Ok(next_index)
 93    }
 94}
 95
/// JSON-serializable snapshot of a request, stored in the cache's `request`
/// column and replayed when building batch uploads.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    /// Forwarded as `max_completion_tokens` when the request is rebuilt.
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
102
/// Flattened (role, text) form of a `RequestMessage` for cache storage.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    /// One of "user", "assistant", "system", or "tool".
    role: String,
    /// Plain-text content; multipart content is joined with newlines.
    content: String,
}
108
109impl BatchingOpenAiClient {
110    fn new(cache_path: &Path) -> Result<Self> {
111        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
112        let api_key = std::env::var("OPENAI_API_KEY")
113            .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable not set"))?;
114
115        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
116        let mut statement = sqlez::statement::Statement::prepare(
117            &connection,
118            indoc! {"
119                CREATE TABLE IF NOT EXISTS openai_cache (
120                    request_hash TEXT PRIMARY KEY,
121                    request TEXT,
122                    response TEXT,
123                    batch_id TEXT
124                );
125                "},
126        )?;
127        statement.exec()?;
128        drop(statement);
129
130        Ok(Self {
131            connection: Mutex::new(connection),
132            http_client,
133            api_key,
134        })
135    }
136
137    pub fn lookup(
138        &self,
139        model: &str,
140        max_tokens: u64,
141        messages: &[RequestMessage],
142        seed: Option<usize>,
143    ) -> Result<Option<OpenAiResponse>> {
144        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
145        let connection = self.connection.lock().unwrap();
146        let response: Vec<String> = connection.select_bound(
147            &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
148        )?(request_hash_str.as_str())?;
149        Ok(response
150            .into_iter()
151            .next()
152            .and_then(|text| serde_json::from_str(&text).ok()))
153    }
154
155    pub fn mark_for_batch(
156        &self,
157        model: &str,
158        max_tokens: u64,
159        messages: &[RequestMessage],
160        seed: Option<usize>,
161    ) -> Result<()> {
162        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
163
164        let serializable_messages: Vec<SerializableMessage> = messages
165            .iter()
166            .map(|msg| SerializableMessage {
167                role: message_role_to_string(msg),
168                content: message_content_to_string(msg),
169            })
170            .collect();
171
172        let serializable_request = SerializableRequest {
173            model: model.to_string(),
174            max_tokens,
175            messages: serializable_messages,
176        };
177
178        let request = Some(serde_json::to_string(&serializable_request)?);
179        let cache_row = CacheRow {
180            request_hash,
181            request,
182            response: None,
183            batch_id: None,
184        };
185        let connection = self.connection.lock().unwrap();
186        connection.exec_bound::<CacheRow>(sql!(
187            INSERT OR IGNORE INTO openai_cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
188            cache_row,
189        )
190    }
191
192    async fn generate(
193        &self,
194        model: &str,
195        max_tokens: u64,
196        messages: Vec<RequestMessage>,
197        seed: Option<usize>,
198        cache_only: bool,
199    ) -> Result<Option<OpenAiResponse>> {
200        let response = self.lookup(model, max_tokens, &messages, seed)?;
201        if let Some(response) = response {
202            return Ok(Some(response));
203        }
204
205        if !cache_only {
206            self.mark_for_batch(model, max_tokens, &messages, seed)?;
207        }
208
209        Ok(None)
210    }
211
212    async fn sync_batches(&self) -> Result<()> {
213        let _batch_ids = self.upload_pending_requests().await?;
214        self.download_finished_batches().await
215    }
216
217    pub fn pending_batch_count(&self) -> Result<usize> {
218        let connection = self.connection.lock().unwrap();
219        let counts: Vec<i32> = connection.select(
220            sql!(SELECT COUNT(*) FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL),
221        )?()?;
222        Ok(counts.into_iter().next().unwrap_or(0) as usize)
223    }
224
    /// Downloads and imports the results of explicitly named batch ids into
    /// the cache, whether or not those batches are already recorded locally.
    ///
    /// Non-completed batches are skipped with a warning. Failed requests are
    /// stored as synthetic JSON error objects so the failure is visible in
    /// the cache rather than staying pending.
    ///
    /// # Errors
    /// Fails on API/download/parse errors, or when a completed batch has no
    /// output file.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing OpenAI batch {}", batch_id);

            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status != "completed" {
                log::warn!(
                    "Batch {} is not completed (status: {}), skipping",
                    batch_id,
                    batch_status.status
                );
                continue;
            }

            let output_file_id = batch_status.output_file_id.ok_or_else(|| {
                anyhow::anyhow!("Batch {} completed but has no output file", batch_id)
            })?;

            let results_content = batches::download_file(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                &output_file_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to download batch results for {}: {:?}", batch_id, e)
            })?;

            let results = batches::parse_batch_output(&results_content)
                .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

            // (request_hash, response_json, batch_id) triples, applied in a
            // single savepoint below.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // custom_id was written as "req_hash_<hash>" at upload time;
                // tolerate ids without the prefix.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                if let Some(response_body) = result.response {
                    if response_body.status_code == 200 {
                        let response_json = serde_json::to_string(&response_body.body)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    } else {
                        log::error!(
                            "Batch request {} failed with status {}",
                            request_hash,
                            response_body.status_code
                        );
                        // Synthetic error object so the row stops being
                        // considered pending.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": "http_error",
                                "status_code": response_body.status_code
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                } else if let Some(error) = result.error {
                    log::error!(
                        "Batch request {} failed: {}: {}",
                        request_hash,
                        error.code,
                        error.message
                    );
                    let error_json = serde_json::json!({
                        "error": {
                            "type": error.code,
                            "message": error.message
                        }
                    })
                    .to_string();
                    updates.push((request_hash, error_json, batch_id.clone()));
                    error_count += 1;
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // INSERT OR REPLACE while preserving any previously stored
                // request body via the self-select on the same hash.
                let q = sql!(
                    INSERT OR REPLACE INTO openai_cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM openai_cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    // The hash is bound twice: once as the new key, once for
                    // the request-preserving subselect.
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
346
    /// Polls every batch id the cache still considers pending and, for each
    /// batch reporting "completed", downloads its output file and writes the
    /// per-request responses (or synthetic error objects) into the cache.
    async fn download_finished_batches(&self) -> Result<()> {
        // Batch ids with at least one row that has no response yet. Collected
        // in a scope so the connection lock is released before awaiting.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = batches::retrieve_batch(
                self.http_client.as_ref(),
                OPEN_AI_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!("Batch {} status: {}", batch_id, batch_status.status);

            if batch_status.status == "completed" {
                let output_file_id = match batch_status.output_file_id {
                    Some(id) => id,
                    None => {
                        // Unlike `import_batches`, a missing output file is
                        // skipped here rather than treated as an error.
                        log::warn!("Batch {} completed but has no output file", batch_id);
                        continue;
                    }
                };

                let results_content = batches::download_file(
                    self.http_client.as_ref(),
                    OPEN_AI_API_URL,
                    &self.api_key,
                    &output_file_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let results = batches::parse_batch_output(&results_content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse batch output: {:?}", e))?;

                // (response_json, request_hash) pairs, applied in a single
                // savepoint below.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;

                for result in results {
                    // custom_id was written as "req_hash_<hash>" at upload
                    // time; tolerate ids without the prefix.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    if let Some(response_body) = result.response {
                        if response_body.status_code == 200 {
                            let response_json = serde_json::to_string(&response_body.body)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        } else {
                            log::error!(
                                "Batch request {} failed with status {}",
                                request_hash,
                                response_body.status_code
                            );
                            // Synthetic error object so the row stops being
                            // considered pending.
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "http_error",
                                    "status_code": response_body.status_code
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    } else if let Some(error) = result.error {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.code,
                            error.message
                        );
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.code,
                                "message": error.message
                            }
                        })
                        .to_string();
                        updates.push((error_json, request_hash));
                    }
                }

                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE openai_cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
450
451    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
452        const BATCH_CHUNK_SIZE: i32 = 16_000;
453        let mut all_batch_ids = Vec::new();
454        let mut total_uploaded = 0;
455
456        loop {
457            let rows: Vec<(String, String)> = {
458                let connection = self.connection.lock().unwrap();
459                let q = sql!(
460                    SELECT request_hash, request FROM openai_cache
461                    WHERE batch_id IS NULL AND response IS NULL
462                    LIMIT ?
463                );
464                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
465            };
466
467            if rows.is_empty() {
468                break;
469            }
470
471            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
472
473            let mut jsonl_content = String::new();
474            for (hash, request_str) in &rows {
475                let serializable_request: SerializableRequest =
476                    serde_json::from_str(request_str).unwrap();
477
478                let messages: Vec<RequestMessage> = serializable_request
479                    .messages
480                    .into_iter()
481                    .map(|msg| match msg.role.as_str() {
482                        "user" => RequestMessage::User {
483                            content: MessageContent::Plain(msg.content),
484                        },
485                        "assistant" => RequestMessage::Assistant {
486                            content: Some(MessageContent::Plain(msg.content)),
487                            tool_calls: Vec::new(),
488                            reasoning_content: None,
489                        },
490                        "system" => RequestMessage::System {
491                            content: MessageContent::Plain(msg.content),
492                        },
493                        _ => RequestMessage::User {
494                            content: MessageContent::Plain(msg.content),
495                        },
496                    })
497                    .collect();
498
499                let request = OpenAiRequest {
500                    model: serializable_request.model,
501                    messages,
502                    stream: false,
503                    stream_options: None,
504                    max_completion_tokens: Some(serializable_request.max_tokens),
505                    stop: Vec::new(),
506                    temperature: None,
507                    tool_choice: None,
508                    parallel_tool_calls: None,
509                    tools: Vec::new(),
510                    prompt_cache_key: None,
511                    reasoning_effort: None,
512                };
513
514                let custom_id = format!("req_hash_{}", hash);
515                let batch_item = batches::BatchRequestItem::new(custom_id, request);
516                let line = batch_item
517                    .to_jsonl_line()
518                    .map_err(|e| anyhow::anyhow!("Failed to serialize batch item: {:?}", e))?;
519                jsonl_content.push_str(&line);
520                jsonl_content.push('\n');
521            }
522
523            let filename = format!("batch_{}.jsonl", chrono::Utc::now().timestamp());
524            let file_obj = batches::upload_batch_file(
525                self.http_client.as_ref(),
526                OPEN_AI_API_URL,
527                &self.api_key,
528                &filename,
529                jsonl_content.into_bytes(),
530            )
531            .await
532            .map_err(|e| anyhow::anyhow!("Failed to upload batch file: {:?}", e))?;
533
534            let batch = batches::create_batch(
535                self.http_client.as_ref(),
536                OPEN_AI_API_URL,
537                &self.api_key,
538                batches::CreateBatchRequest::new(file_obj.id),
539            )
540            .await
541            .map_err(|e| anyhow::anyhow!("Failed to create batch: {:?}", e))?;
542
543            {
544                let connection = self.connection.lock().unwrap();
545                connection.with_savepoint("batch_upload", || {
546                    let q = sql!(UPDATE openai_cache SET batch_id = ? WHERE request_hash = ?);
547                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
548                    for hash in &request_hashes {
549                        exec((batch.id.as_str(), hash.as_str()))?;
550                    }
551                    Ok(())
552                })?;
553            }
554
555            let batch_len = rows.len();
556            total_uploaded += batch_len;
557            log::info!(
558                "Uploaded batch {} with {} requests ({} total)",
559                batch.id,
560                batch_len,
561                total_uploaded
562            );
563
564            all_batch_ids.push(batch.id);
565        }
566
567        if !all_batch_ids.is_empty() {
568            log::info!(
569                "Finished uploading {} batches with {} total requests",
570                all_batch_ids.len(),
571                total_uploaded
572            );
573        }
574
575        Ok(all_batch_ids)
576    }
577
578    fn request_hash(
579        model: &str,
580        max_tokens: u64,
581        messages: &[RequestMessage],
582        seed: Option<usize>,
583    ) -> String {
584        let mut hasher = std::hash::DefaultHasher::new();
585        "openai".hash(&mut hasher);
586        model.hash(&mut hasher);
587        max_tokens.hash(&mut hasher);
588        for msg in messages {
589            message_content_to_string(msg).hash(&mut hasher);
590        }
591        if let Some(seed) = seed {
592            seed.hash(&mut hasher);
593        }
594        let request_hash = hasher.finish();
595        format!("{request_hash:016x}")
596    }
597}
598
599fn message_role_to_string(msg: &RequestMessage) -> String {
600    match msg {
601        RequestMessage::User { .. } => "user".to_string(),
602        RequestMessage::Assistant { .. } => "assistant".to_string(),
603        RequestMessage::System { .. } => "system".to_string(),
604        RequestMessage::Tool { .. } => "tool".to_string(),
605    }
606}
607
608fn message_content_to_string(msg: &RequestMessage) -> String {
609    match msg {
610        RequestMessage::User { content } => content_to_string(content),
611        RequestMessage::Assistant { content, .. } => {
612            content.as_ref().map(content_to_string).unwrap_or_default()
613        }
614        RequestMessage::System { content } => content_to_string(content),
615        RequestMessage::Tool { content, .. } => content_to_string(content),
616    }
617}
618
619fn content_to_string(content: &MessageContent) -> String {
620    match content {
621        MessageContent::Plain(text) => text.clone(),
622        MessageContent::Multipart(parts) => parts
623            .iter()
624            .filter_map(|part| match part {
625                open_ai::MessagePart::Text { text } => Some(text.clone()),
626                _ => None,
627            })
628            .collect::<Vec<String>>()
629            .join("\n"),
630    }
631}
632
/// Front-end over the two client implementations, plus a placeholder for
/// contexts where no OpenAI access is expected.
pub enum OpenAiClient {
    /// Sends each request immediately; no caching.
    Plain(PlainOpenAiClient),
    /// Caches requests and serves them through the OpenAI Batch API.
    Batch(BatchingOpenAiClient),
    /// Placeholder; every method panics if called on this variant.
    #[allow(dead_code)]
    Dummy,
}
639
640impl OpenAiClient {
641    pub fn plain() -> Result<Self> {
642        Ok(Self::Plain(PlainOpenAiClient::new()?))
643    }
644
645    pub fn batch(cache_path: &Path) -> Result<Self> {
646        Ok(Self::Batch(BatchingOpenAiClient::new(cache_path)?))
647    }
648
649    #[allow(dead_code)]
650    pub fn dummy() -> Self {
651        Self::Dummy
652    }
653
654    pub async fn generate(
655        &self,
656        model: &str,
657        max_tokens: u64,
658        messages: Vec<RequestMessage>,
659        seed: Option<usize>,
660        cache_only: bool,
661    ) -> Result<Option<OpenAiResponse>> {
662        match self {
663            OpenAiClient::Plain(plain_client) => plain_client
664                .generate(model, max_tokens, messages)
665                .await
666                .map(Some),
667            OpenAiClient::Batch(batching_client) => {
668                batching_client
669                    .generate(model, max_tokens, messages, seed, cache_only)
670                    .await
671            }
672            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
673        }
674    }
675
676    pub async fn sync_batches(&self) -> Result<()> {
677        match self {
678            OpenAiClient::Plain(_) => Ok(()),
679            OpenAiClient::Batch(batching_client) => batching_client.sync_batches().await,
680            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
681        }
682    }
683
684    pub fn pending_batch_count(&self) -> Result<usize> {
685        match self {
686            OpenAiClient::Plain(_) => Ok(0),
687            OpenAiClient::Batch(batching_client) => batching_client.pending_batch_count(),
688            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
689        }
690    }
691
692    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
693        match self {
694            OpenAiClient::Plain(_) => {
695                anyhow::bail!("Import batches is only supported with batching client")
696            }
697            OpenAiClient::Batch(batching_client) => batching_client.import_batches(batch_ids).await,
698            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
699        }
700    }
701}