1use anthropic::{
2 ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
4 stream_completion,
5};
6use anyhow::Result;
7use futures::StreamExt as _;
8use http_client::HttpClient;
9use indoc::indoc;
10use reqwest_client::ReqwestClient;
11use sqlez::bindable::Bind;
12use sqlez::bindable::StaticColumnCount;
13use sqlez_macros::sql;
14use std::collections::HashSet;
15use std::hash::Hash;
16use std::hash::Hasher;
17use std::path::Path;
18use std::sync::{Arc, Mutex};
19
/// A direct (non-batching, non-caching) Anthropic API client.
///
/// Each call to [`PlainLlmClient::generate`] issues one HTTP request.
pub struct PlainLlmClient {
    // Shared HTTP transport; trait object so the transport can be swapped.
    pub http_client: Arc<dyn HttpClient>,
    // API key read from the `ANTHROPIC_API_KEY` environment variable in `new`.
    pub api_key: String,
}
24
25impl PlainLlmClient {
26 pub fn new() -> Result<Self> {
27 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
28 let api_key = std::env::var("ANTHROPIC_API_KEY")
29 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
30 Ok(Self {
31 http_client,
32 api_key,
33 })
34 }
35
36 pub async fn generate(
37 &self,
38 model: &str,
39 max_tokens: u64,
40 messages: Vec<Message>,
41 ) -> Result<AnthropicResponse> {
42 let request = AnthropicRequest {
43 model: model.to_string(),
44 max_tokens,
45 messages,
46 tools: Vec::new(),
47 thinking: None,
48 tool_choice: None,
49 system: None,
50 metadata: None,
51 output_config: None,
52 stop_sequences: Vec::new(),
53 speed: None,
54 temperature: None,
55 top_k: None,
56 top_p: None,
57 };
58
59 let response = non_streaming_completion(
60 self.http_client.as_ref(),
61 ANTHROPIC_API_URL,
62 &self.api_key,
63 request,
64 None,
65 )
66 .await
67 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
68
69 Ok(response)
70 }
71
72 pub async fn generate_streaming<F>(
73 &self,
74 model: &str,
75 max_tokens: u64,
76 messages: Vec<Message>,
77 mut on_progress: F,
78 ) -> Result<AnthropicResponse>
79 where
80 F: FnMut(usize, &str),
81 {
82 let request = AnthropicRequest {
83 model: model.to_string(),
84 max_tokens,
85 messages,
86 tools: Vec::new(),
87 thinking: None,
88 tool_choice: None,
89 system: None,
90 metadata: None,
91 output_config: None,
92 stop_sequences: Vec::new(),
93 speed: None,
94 temperature: None,
95 top_k: None,
96 top_p: None,
97 };
98
99 let mut stream = stream_completion(
100 self.http_client.as_ref(),
101 ANTHROPIC_API_URL,
102 &self.api_key,
103 request,
104 None,
105 )
106 .await
107 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
108
109 let mut response: Option<AnthropicResponse> = None;
110 let mut text_content = String::new();
111
112 while let Some(event_result) = stream.next().await {
113 let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
114
115 match event {
116 Event::MessageStart { message } => {
117 response = Some(message);
118 }
119 Event::ContentBlockDelta { delta, .. } => {
120 if let anthropic::ContentDelta::TextDelta { text } = delta {
121 text_content.push_str(&text);
122 on_progress(text_content.len(), &text_content);
123 }
124 }
125 _ => {}
126 }
127 }
128
129 let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
130
131 if response.content.is_empty() && !text_content.is_empty() {
132 response
133 .content
134 .push(ResponseContent::Text { text: text_content });
135 }
136
137 Ok(response)
138 }
139}
140
/// An Anthropic client that defers requests to the batches API, caching
/// requests and responses in a local SQLite database.
pub struct BatchingLlmClient {
    // SQLite cache; guarded by a Mutex because sqlez connections are not Sync.
    // NOTE(review): assumed rationale — confirm against sqlez's thread-safety docs.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}
146
/// One row of the `cache` table: a request keyed by its hash, plus the
/// serialized request/response JSON and the batch it was submitted in.
struct CacheRow {
    // Primary key; see `BatchingLlmClient::request_hash`.
    request_hash: String,
    // Serialized `SerializableRequest` JSON, if known.
    request: Option<String>,
    // Serialized `AnthropicResponse` JSON (or an error payload) once available.
    response: Option<String>,
    // Anthropic batch id once the request has been uploaded.
    batch_id: Option<String>,
}
153
impl StaticColumnCount for CacheRow {
    /// `CacheRow` binds exactly four columns:
    /// request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
159
160impl Bind for CacheRow {
161 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
162 let next_index = statement.bind(&self.request_hash, start_index)?;
163 let next_index = statement.bind(&self.request, next_index)?;
164 let next_index = statement.bind(&self.response, next_index)?;
165 let next_index = statement.bind(&self.batch_id, next_index)?;
166 Ok(next_index)
167 }
168}
169
/// JSON-storable form of a request, persisted in the cache's `request` column
/// so pending requests survive process restarts.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
176
/// A message flattened to a role string ("user"/"assistant") and plain text;
/// non-text content is dropped by `message_content_to_string`.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
182
impl BatchingLlmClient {
    /// Opens (or creates) the SQLite cache at `cache_path` and ensures the
    /// `cache` table exists.
    ///
    /// # Errors
    /// Fails if `ANTHROPIC_API_KEY` is unset, if `cache_path` is not valid
    /// UTF-8, or if the schema statement cannot be prepared or executed.
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;

        // `new` already returns Result, so surface a non-UTF-8 path as an
        // error instead of panicking via unwrap().
        let cache_path = cache_path
            .to_str()
            .ok_or_else(|| anyhow::anyhow!("cache path is not valid UTF-8: {:?}", cache_path))?;
        let connection = sqlez::connection::Connection::open_file(cache_path);
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
            "},
        )?;
        statement.exec()?;
        // Drop the statement before moving the connection into the Mutex.
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }
210
    /// Returns the cached response for this request, if one has been stored.
    ///
    /// Rows whose `response` column fails to deserialize as an
    /// `AnthropicResponse` — including stored error payloads written by the
    /// batch download paths — are treated as cache misses (`Ok(None)`).
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> Result<Option<AnthropicResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            // Deliberately swallow parse failures: error JSON → cache miss.
            .and_then(|text| serde_json::from_str(&text).ok()))
    }
228
229 pub fn mark_for_batch(
230 &self,
231 model: &str,
232 max_tokens: u64,
233 messages: &[Message],
234 seed: Option<usize>,
235 ) -> Result<()> {
236 let request_hash = Self::request_hash(model, max_tokens, messages, seed);
237
238 let serializable_messages: Vec<SerializableMessage> = messages
239 .iter()
240 .map(|msg| SerializableMessage {
241 role: match msg.role {
242 Role::User => "user".to_string(),
243 Role::Assistant => "assistant".to_string(),
244 },
245 content: message_content_to_string(&msg.content),
246 })
247 .collect();
248
249 let serializable_request = SerializableRequest {
250 model: model.to_string(),
251 max_tokens,
252 messages: serializable_messages,
253 };
254
255 let request = Some(serde_json::to_string(&serializable_request)?);
256 let cache_row = CacheRow {
257 request_hash,
258 request,
259 response: None,
260 batch_id: None,
261 };
262 let connection = self.connection.lock().unwrap();
263 connection.exec_bound::<CacheRow>(sql!(
264 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
265 cache_row,
266 )
267 }
268
269 async fn generate(
270 &self,
271 model: &str,
272 max_tokens: u64,
273 messages: Vec<Message>,
274 seed: Option<usize>,
275 cache_only: bool,
276 ) -> Result<Option<AnthropicResponse>> {
277 let response = self.lookup(model, max_tokens, &messages, seed)?;
278 if let Some(response) = response {
279 return Ok(Some(response));
280 }
281
282 if !cache_only {
283 self.mark_for_batch(model, max_tokens, &messages, seed)?;
284 }
285
286 Ok(None)
287 }
288
289 /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
290 async fn sync_batches(&self) -> Result<()> {
291 let _batch_ids = self.upload_pending_requests().await?;
292 self.download_finished_batches().await
293 }
294
295 pub fn pending_batch_count(&self) -> Result<usize> {
296 let connection = self.connection.lock().unwrap();
297 let counts: Vec<i32> = connection.select(
298 sql!(SELECT COUNT(*) FROM cache WHERE batch_id IS NOT NULL AND response IS NULL),
299 )?()?;
300 Ok(counts.into_iter().next().unwrap_or(0) as usize)
301 }
302
    /// Import batch results from external batch IDs (useful for recovering after database loss)
    ///
    /// For each id: fetches the batch's status, skips it unless processing has
    /// "ended", then writes every result into the local cache keyed by the
    /// request hash embedded in the result's `custom_id`.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing batch {}", batch_id);

            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Only finished batches have downloadable results.
            if batch_status.processing_status != "ended" {
                log::warn!(
                    "Batch {} is not finished (status: {}), skipping",
                    batch_id,
                    batch_status.processing_status
                );
                continue;
            }

            let results = anthropic::batches::retrieve_batch_results(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
            })?;

            // (request_hash, response_json, batch_id) triples to persist below.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // custom_id was written as "req_hash_<hash>" at upload time;
                // fall back to the raw custom_id if the prefix is missing.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                match result.result {
                    anthropic::batches::BatchResult::Succeeded { message } => {
                        let response_json = serde_json::to_string(&message)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    }
                    anthropic::batches::BatchResult::Errored { error } => {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.error.error_type,
                            error.error.message
                        );
                        // Persist the error payload so the row is no longer
                        // counted as pending.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.error.error_type,
                                "message": error.error.message
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Canceled => {
                        // NOTE(review): unlike download_finished_batches, canceled
                        // results are logged but not persisted here, leaving the
                        // row (if any) eligible for re-upload — confirm intended.
                        log::warn!("Batch request {} was canceled", request_hash);
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Expired => {
                        // NOTE(review): same asymmetry as Canceled above.
                        log::warn!("Batch request {} expired", request_hash);
                        error_count += 1;
                    }
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
                // (the subselect carries forward any stored request body).
                let q = sql!(
                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        // Bound twice: once as the new key, once for the subselect.
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
417
    /// Downloads results for every batch that has rows still awaiting a
    /// response, and stores each result (success or error payload) in the
    /// cache so the row stops counting as pending.
    async fn download_finished_batches(&self) -> Result<()> {
        // Batch ids with at least one row still lacking a response.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Unfinished batches are simply skipped until a later sync.
            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                // (response_json, request_hash) pairs for the UPDATE below —
                // note the order matches the statement's placeholders.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;
                for result in results {
                    // custom_id was written as "req_hash_<hash>" at upload time.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        }
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!(
                                "Batch request {} failed: {}: {}",
                                request_hash,
                                error.error.error_type,
                                error.error.message
                            );
                            // Store the error payload so the row is not
                            // re-downloaded (lookup treats it as a miss).
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": error.error.error_type,
                                    "message": error.error.message
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "canceled",
                                    "message": "Batch request was canceled"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "expired",
                                    "message": "Batch request expired"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    }
                }

                // Apply all updates for this batch atomically.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
522
523 async fn upload_pending_requests(&self) -> Result<Vec<String>> {
524 const BATCH_CHUNK_SIZE: i32 = 16_000;
525 const MAX_BATCH_SIZE_BYTES: usize = 100 * 1024 * 1024;
526
527 let mut all_batch_ids = Vec::new();
528 let mut total_uploaded = 0;
529
530 let mut current_batch_rows = Vec::new();
531 let mut current_batch_size = 0usize;
532 let mut pending_hashes: HashSet<String> = HashSet::new();
533 loop {
534 let rows: Vec<(String, String)> = {
535 let connection = self.connection.lock().unwrap();
536 let q = sql!(
537 SELECT request_hash, request FROM cache
538 WHERE batch_id IS NULL AND response IS NULL
539 LIMIT ?
540 );
541 connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
542 };
543
544 if rows.is_empty() {
545 break;
546 }
547
548 // Split rows into sub-batches based on size
549 let mut batches_to_upload = Vec::new();
550 let mut new_rows_added = 0;
551
552 for row in rows {
553 let (hash, request_str) = row;
554
555 // Skip rows already added to current_batch_rows but not yet uploaded
556 if pending_hashes.contains(&hash) {
557 continue;
558 }
559 let serializable_request: SerializableRequest = serde_json::from_str(&request_str)?;
560
561 let messages: Vec<Message> = serializable_request
562 .messages
563 .into_iter()
564 .map(|msg| Message {
565 role: match msg.role.as_str() {
566 "user" => Role::User,
567 "assistant" => Role::Assistant,
568 _ => Role::User,
569 },
570 content: vec![RequestContent::Text {
571 text: msg.content,
572 cache_control: None,
573 }],
574 })
575 .collect();
576
577 let params = AnthropicRequest {
578 model: serializable_request.model,
579 max_tokens: serializable_request.max_tokens,
580 messages,
581 tools: Vec::new(),
582 thinking: None,
583 tool_choice: None,
584 system: None,
585 metadata: None,
586 output_config: None,
587 stop_sequences: Vec::new(),
588 temperature: None,
589 top_k: None,
590 top_p: None,
591 speed: None,
592 };
593
594 let custom_id = format!("req_hash_{}", hash);
595 let batch_request = anthropic::batches::BatchRequest { custom_id, params };
596
597 // Estimate the serialized size of this request
598 let estimated_size = serde_json::to_string(&batch_request)?.len();
599
600 // If adding this request would exceed the limit, start a new batch
601 if !current_batch_rows.is_empty()
602 && current_batch_size + estimated_size > MAX_BATCH_SIZE_BYTES
603 {
604 batches_to_upload.push((current_batch_rows, current_batch_size));
605 current_batch_rows = Vec::new();
606 current_batch_size = 0;
607 }
608
609 pending_hashes.insert(hash.clone());
610 current_batch_rows.push((hash, batch_request));
611 current_batch_size += estimated_size;
612 new_rows_added += 1;
613 }
614
615 // If no new rows were added this iteration, all pending requests are already
616 // in current_batch_rows, so we should break to avoid an infinite loop
617 if new_rows_added == 0 {
618 break;
619 }
620
621 // Only upload full batches, keep the partial batch for the next iteration
622 // Upload each sub-batch
623 for (batch_rows, batch_size) in batches_to_upload {
624 let request_hashes: Vec<String> =
625 batch_rows.iter().map(|(hash, _)| hash.clone()).collect();
626
627 // Remove uploaded hashes from pending set
628 for hash in &request_hashes {
629 pending_hashes.remove(hash);
630 }
631 let batch_requests: Vec<anthropic::batches::BatchRequest> =
632 batch_rows.into_iter().map(|(_, req)| req).collect();
633
634 let batch_len = batch_requests.len();
635 log::info!(
636 "Uploading batch with {} requests (~{:.2} MB)",
637 batch_len,
638 batch_size as f64 / (1024.0 * 1024.0)
639 );
640
641 let batch = anthropic::batches::create_batch(
642 self.http_client.as_ref(),
643 ANTHROPIC_API_URL,
644 &self.api_key,
645 anthropic::batches::CreateBatchRequest {
646 requests: batch_requests,
647 },
648 )
649 .await
650 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
651
652 {
653 let connection = self.connection.lock().unwrap();
654 connection.with_savepoint("batch_upload", || {
655 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
656 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
657 for hash in &request_hashes {
658 exec((batch.id.as_str(), hash.as_str()))?;
659 }
660 Ok(())
661 })?;
662 }
663
664 total_uploaded += batch_len;
665 log::info!(
666 "Uploaded batch {} with {} requests ({} total)",
667 batch.id,
668 batch_len,
669 total_uploaded
670 );
671
672 all_batch_ids.push(batch.id);
673 }
674 }
675
676 // Upload any remaining partial batch at the end
677 if !current_batch_rows.is_empty() {
678 let request_hashes: Vec<String> = current_batch_rows
679 .iter()
680 .map(|(hash, _)| hash.clone())
681 .collect();
682 let batch_requests: Vec<anthropic::batches::BatchRequest> =
683 current_batch_rows.into_iter().map(|(_, req)| req).collect();
684
685 let batch_len = batch_requests.len();
686 log::info!(
687 "Uploading final batch with {} requests (~{:.2} MB)",
688 batch_len,
689 current_batch_size as f64 / (1024.0 * 1024.0)
690 );
691
692 let batch = anthropic::batches::create_batch(
693 self.http_client.as_ref(),
694 ANTHROPIC_API_URL,
695 &self.api_key,
696 anthropic::batches::CreateBatchRequest {
697 requests: batch_requests,
698 },
699 )
700 .await
701 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
702
703 {
704 let connection = self.connection.lock().unwrap();
705 connection.with_savepoint("batch_upload", || {
706 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
707 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
708 for hash in &request_hashes {
709 exec((batch.id.as_str(), hash.as_str()))?;
710 }
711 Ok(())
712 })?;
713 }
714
715 total_uploaded += batch_len;
716 log::info!(
717 "Uploaded batch {} with {} requests ({} total)",
718 batch.id,
719 batch_len,
720 total_uploaded
721 );
722
723 all_batch_ids.push(batch.id);
724 }
725
726 if !all_batch_ids.is_empty() {
727 log::info!(
728 "Finished uploading {} batches with {} total requests",
729 all_batch_ids.len(),
730 total_uploaded
731 );
732 }
733
734 Ok(all_batch_ids)
735 }
736
    /// Computes the cache key for a request: a 16-hex-digit hash of the model,
    /// max_tokens, each message's text content, and the seed (if any).
    ///
    /// `DefaultHasher::new()` uses fixed keys, so the key is stable across
    /// runs of the same build; it is NOT guaranteed stable across Rust
    /// versions, which would invalidate an existing cache after a toolchain
    /// upgrade.
    ///
    /// NOTE(review): message roles are not hashed — two conversations with
    /// identical text but different role assignments collide. Confirm whether
    /// that is acceptable before relying on the cache across role variations.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(&msg.content).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        // Zero-padded lowercase hex, always 16 characters.
        format!("{request_hash:016x}")
    }
}
756
757fn message_content_to_string(content: &[RequestContent]) -> String {
758 content
759 .iter()
760 .filter_map(|c| match c {
761 RequestContent::Text { text, .. } => Some(text.clone()),
762 _ => None,
763 })
764 .collect::<Vec<String>>()
765 .join("\n")
766}
767
/// Facade over the two client implementations, plus a panicking placeholder.
pub enum AnthropicClient {
    // No batching — each request is sent immediately.
    Plain(PlainLlmClient),
    // Requests are cached locally and submitted via the batches API.
    Batch(BatchingLlmClient),
    // Panics on use; for code paths that must never reach the LLM.
    Dummy,
}
774
impl AnthropicClient {
    /// Creates a non-batching client. Fails if `ANTHROPIC_API_KEY` is unset.
    pub fn plain() -> Result<Self> {
        Ok(Self::Plain(PlainLlmClient::new()?))
    }

    /// Creates a batching client backed by a SQLite cache at `cache_path`.
    pub fn batch(cache_path: &Path) -> Result<Self> {
        Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
    }

    /// Creates a placeholder client that panics if any method is called.
    #[allow(dead_code)]
    pub fn dummy() -> Self {
        Self::Dummy
    }

    /// Generates a completion.
    ///
    /// Plain: always issues a request and returns `Some`. Batch: returns the
    /// cached response, or `None` after (unless `cache_only`) enqueuing the
    /// request for the next batch upload. `seed` and `cache_only` only affect
    /// the batch variant.
    ///
    /// # Panics
    /// Panics on the `Dummy` variant.
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
        seed: Option<usize>,
        cache_only: bool,
    ) -> Result<Option<AnthropicResponse>> {
        match self {
            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
                .generate(model, max_tokens, messages)
                .await
                .map(Some),
            AnthropicClient::Batch(batching_llm_client) => {
                batching_llm_client
                    .generate(model, max_tokens, messages, seed, cache_only)
                    .await
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Streams a completion; only supported by the plain variant.
    ///
    /// # Errors
    /// Fails on the batch variant, which cannot stream.
    ///
    /// # Panics
    /// Panics on the `Dummy` variant.
    #[allow(dead_code)]
    pub async fn generate_streaming<F>(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
        on_progress: F,
    ) -> Result<Option<AnthropicResponse>>
    where
        F: FnMut(usize, &str),
    {
        match self {
            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
                .generate_streaming(model, max_tokens, messages, on_progress)
                .await
                .map(Some),
            AnthropicClient::Batch(_) => {
                anyhow::bail!("Streaming not supported with batching client")
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Uploads pending requests and downloads finished batches.
    /// No-op for the plain variant; panics on `Dummy`.
    pub async fn sync_batches(&self) -> Result<()> {
        match self {
            AnthropicClient::Plain(_) => Ok(()),
            AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Number of uploaded requests still awaiting a response.
    /// Always 0 for the plain variant; panics on `Dummy`.
    pub fn pending_batch_count(&self) -> Result<usize> {
        match self {
            AnthropicClient::Plain(_) => Ok(0),
            AnthropicClient::Batch(batching_llm_client) => {
                batching_llm_client.pending_batch_count()
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Imports results from externally known batch ids into the local cache.
    ///
    /// # Errors
    /// Fails on the plain variant, which has no cache.
    ///
    /// # Panics
    /// Panics on the `Dummy` variant.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        match self {
            AnthropicClient::Plain(_) => {
                anyhow::bail!("Import batches is only supported with batching client")
            }
            AnthropicClient::Batch(batching_llm_client) => {
                batching_llm_client.import_batches(batch_ids).await
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }
}