// anthropic_client.rs

  1use anthropic::{
  2    ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
  3    Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
  4    stream_completion,
  5};
  6use anyhow::Result;
  7use futures::StreamExt as _;
  8use http_client::HttpClient;
  9use indoc::indoc;
 10use reqwest_client::ReqwestClient;
 11use sqlez::bindable::Bind;
 12use sqlez::bindable::StaticColumnCount;
 13use sqlez_macros::sql;
 14use std::collections::HashSet;
 15use std::hash::Hash;
 16use std::hash::Hasher;
 17use std::path::Path;
 18use std::sync::{Arc, Mutex};
 19
/// A thin, non-batching Anthropic client: every call issues an immediate HTTP
/// request against the Messages API.
pub struct PlainLlmClient {
    // Shared HTTP transport (constructed as a ReqwestClient in `new`).
    pub http_client: Arc<dyn HttpClient>,
    // Read from the ANTHROPIC_API_KEY environment variable in `new`.
    pub api_key: String,
}
 24
 25impl PlainLlmClient {
 26    pub fn new() -> Result<Self> {
 27        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 28        let api_key = std::env::var("ANTHROPIC_API_KEY")
 29            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
 30        Ok(Self {
 31            http_client,
 32            api_key,
 33        })
 34    }
 35
 36    pub async fn generate(
 37        &self,
 38        model: &str,
 39        max_tokens: u64,
 40        messages: Vec<Message>,
 41    ) -> Result<AnthropicResponse> {
 42        let request = AnthropicRequest {
 43            model: model.to_string(),
 44            max_tokens,
 45            messages,
 46            tools: Vec::new(),
 47            thinking: None,
 48            tool_choice: None,
 49            system: None,
 50            metadata: None,
 51            output_config: None,
 52            stop_sequences: Vec::new(),
 53            temperature: None,
 54            top_k: None,
 55            top_p: None,
 56        };
 57
 58        let response = non_streaming_completion(
 59            self.http_client.as_ref(),
 60            ANTHROPIC_API_URL,
 61            &self.api_key,
 62            request,
 63            None,
 64        )
 65        .await
 66        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 67
 68        Ok(response)
 69    }
 70
 71    pub async fn generate_streaming<F>(
 72        &self,
 73        model: &str,
 74        max_tokens: u64,
 75        messages: Vec<Message>,
 76        mut on_progress: F,
 77    ) -> Result<AnthropicResponse>
 78    where
 79        F: FnMut(usize, &str),
 80    {
 81        let request = AnthropicRequest {
 82            model: model.to_string(),
 83            max_tokens,
 84            messages,
 85            tools: Vec::new(),
 86            thinking: None,
 87            tool_choice: None,
 88            system: None,
 89            metadata: None,
 90            output_config: None,
 91            stop_sequences: Vec::new(),
 92            temperature: None,
 93            top_k: None,
 94            top_p: None,
 95        };
 96
 97        let mut stream = stream_completion(
 98            self.http_client.as_ref(),
 99            ANTHROPIC_API_URL,
100            &self.api_key,
101            request,
102            None,
103        )
104        .await
105        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
106
107        let mut response: Option<AnthropicResponse> = None;
108        let mut text_content = String::new();
109
110        while let Some(event_result) = stream.next().await {
111            let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
112
113            match event {
114                Event::MessageStart { message } => {
115                    response = Some(message);
116                }
117                Event::ContentBlockDelta { delta, .. } => {
118                    if let anthropic::ContentDelta::TextDelta { text } = delta {
119                        text_content.push_str(&text);
120                        on_progress(text_content.len(), &text_content);
121                    }
122                }
123                _ => {}
124            }
125        }
126
127        let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
128
129        if response.content.is_empty() && !text_content.is_empty() {
130            response
131                .content
132                .push(ResponseContent::Text { text: text_content });
133        }
134
135        Ok(response)
136    }
137}
138
/// An Anthropic client that defers requests to the Batches API, using a local
/// SQLite cache (one row per request hash) to track pending and served work.
pub struct BatchingLlmClient {
    // SQLite cache guarded by a mutex so `&self` methods can use it.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    // Read from the ANTHROPIC_API_KEY environment variable in `new`.
    api_key: String,
}
144
/// One row of the `cache` table (schema created in `BatchingLlmClient::new`).
struct CacheRow {
    // Hex-encoded request hash; primary key.
    request_hash: String,
    // Serialized `SerializableRequest` JSON, when the request body is stored.
    request: Option<String>,
    // Serialized response JSON (or an error object), once available.
    response: Option<String>,
    // Anthropic batch id, set after the request has been uploaded.
    batch_id: Option<String>,
}
151
impl StaticColumnCount for CacheRow {
    // CacheRow binds exactly four columns:
    // request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
157
158impl Bind for CacheRow {
159    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
160        let next_index = statement.bind(&self.request_hash, start_index)?;
161        let next_index = statement.bind(&self.request, next_index)?;
162        let next_index = statement.bind(&self.response, next_index)?;
163        let next_index = statement.bind(&self.batch_id, next_index)?;
164        Ok(next_index)
165    }
166}
167
/// JSON-serializable form of a request, stored in the cache's `request`
/// column so pending work survives process restarts.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
174
/// Flattened message for storage: role as a string ("user"/"assistant") and
/// content collapsed to plain text (see `message_content_to_string`).
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
180
181impl BatchingLlmClient {
182    fn new(cache_path: &Path) -> Result<Self> {
183        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
184        let api_key = std::env::var("ANTHROPIC_API_KEY")
185            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
186
187        let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
188        let mut statement = sqlez::statement::Statement::prepare(
189            &connection,
190            indoc! {"
191                CREATE TABLE IF NOT EXISTS cache (
192                    request_hash TEXT PRIMARY KEY,
193                    request TEXT,
194                    response TEXT,
195                    batch_id TEXT
196                );
197                "},
198        )?;
199        statement.exec()?;
200        drop(statement);
201
202        Ok(Self {
203            connection: Mutex::new(connection),
204            http_client,
205            api_key,
206        })
207    }
208
209    pub fn lookup(
210        &self,
211        model: &str,
212        max_tokens: u64,
213        messages: &[Message],
214        seed: Option<usize>,
215    ) -> Result<Option<AnthropicResponse>> {
216        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
217        let connection = self.connection.lock().unwrap();
218        let response: Vec<String> = connection.select_bound(
219            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
220        )?(request_hash_str.as_str())?;
221        Ok(response
222            .into_iter()
223            .next()
224            .and_then(|text| serde_json::from_str(&text).ok()))
225    }
226
227    pub fn mark_for_batch(
228        &self,
229        model: &str,
230        max_tokens: u64,
231        messages: &[Message],
232        seed: Option<usize>,
233    ) -> Result<()> {
234        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
235
236        let serializable_messages: Vec<SerializableMessage> = messages
237            .iter()
238            .map(|msg| SerializableMessage {
239                role: match msg.role {
240                    Role::User => "user".to_string(),
241                    Role::Assistant => "assistant".to_string(),
242                },
243                content: message_content_to_string(&msg.content),
244            })
245            .collect();
246
247        let serializable_request = SerializableRequest {
248            model: model.to_string(),
249            max_tokens,
250            messages: serializable_messages,
251        };
252
253        let request = Some(serde_json::to_string(&serializable_request)?);
254        let cache_row = CacheRow {
255            request_hash,
256            request,
257            response: None,
258            batch_id: None,
259        };
260        let connection = self.connection.lock().unwrap();
261        connection.exec_bound::<CacheRow>(sql!(
262            INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
263            cache_row,
264        )
265    }
266
267    async fn generate(
268        &self,
269        model: &str,
270        max_tokens: u64,
271        messages: Vec<Message>,
272        seed: Option<usize>,
273        cache_only: bool,
274    ) -> Result<Option<AnthropicResponse>> {
275        let response = self.lookup(model, max_tokens, &messages, seed)?;
276        if let Some(response) = response {
277            return Ok(Some(response));
278        }
279
280        if !cache_only {
281            self.mark_for_batch(model, max_tokens, &messages, seed)?;
282        }
283
284        Ok(None)
285    }
286
287    /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
288    async fn sync_batches(&self) -> Result<()> {
289        let _batch_ids = self.upload_pending_requests().await?;
290        self.download_finished_batches().await
291    }
292
293    /// Import batch results from external batch IDs (useful for recovering after database loss)
294    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
295        for batch_id in batch_ids {
296            log::info!("Importing batch {}", batch_id);
297
298            let batch_status = anthropic::batches::retrieve_batch(
299                self.http_client.as_ref(),
300                ANTHROPIC_API_URL,
301                &self.api_key,
302                batch_id,
303            )
304            .await
305            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;
306
307            log::info!(
308                "Batch {} status: {}",
309                batch_id,
310                batch_status.processing_status
311            );
312
313            if batch_status.processing_status != "ended" {
314                log::warn!(
315                    "Batch {} is not finished (status: {}), skipping",
316                    batch_id,
317                    batch_status.processing_status
318                );
319                continue;
320            }
321
322            let results = anthropic::batches::retrieve_batch_results(
323                self.http_client.as_ref(),
324                ANTHROPIC_API_URL,
325                &self.api_key,
326                batch_id,
327            )
328            .await
329            .map_err(|e| {
330                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
331            })?;
332
333            let mut updates: Vec<(String, String, String)> = Vec::new();
334            let mut success_count = 0;
335            let mut error_count = 0;
336
337            for result in results {
338                let request_hash = result
339                    .custom_id
340                    .strip_prefix("req_hash_")
341                    .unwrap_or(&result.custom_id)
342                    .to_string();
343
344                match result.result {
345                    anthropic::batches::BatchResult::Succeeded { message } => {
346                        let response_json = serde_json::to_string(&message)?;
347                        updates.push((request_hash, response_json, batch_id.clone()));
348                        success_count += 1;
349                    }
350                    anthropic::batches::BatchResult::Errored { error } => {
351                        log::error!(
352                            "Batch request {} failed: {}: {}",
353                            request_hash,
354                            error.error.error_type,
355                            error.error.message
356                        );
357                        let error_json = serde_json::json!({
358                            "error": {
359                                "type": error.error.error_type,
360                                "message": error.error.message
361                            }
362                        })
363                        .to_string();
364                        updates.push((request_hash, error_json, batch_id.clone()));
365                        error_count += 1;
366                    }
367                    anthropic::batches::BatchResult::Canceled => {
368                        log::warn!("Batch request {} was canceled", request_hash);
369                        error_count += 1;
370                    }
371                    anthropic::batches::BatchResult::Expired => {
372                        log::warn!("Batch request {} expired", request_hash);
373                        error_count += 1;
374                    }
375                }
376            }
377
378            let connection = self.connection.lock().unwrap();
379            connection.with_savepoint("batch_import", || {
380                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
381                let q = sql!(
382                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
383                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
384                );
385                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
386                for (request_hash, response_json, batch_id) in &updates {
387                    exec((
388                        request_hash.as_str(),
389                        request_hash.as_str(),
390                        response_json.as_str(),
391                        batch_id.as_str(),
392                    ))?;
393                }
394                Ok(())
395            })?;
396
397            log::info!(
398                "Imported batch {}: {} successful, {} errors",
399                batch_id,
400                success_count,
401                error_count
402            );
403        }
404
405        Ok(())
406    }
407
408    async fn download_finished_batches(&self) -> Result<()> {
409        let batch_ids: Vec<String> = {
410            let connection = self.connection.lock().unwrap();
411            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
412            connection.select(q)?()?
413        };
414
415        for batch_id in &batch_ids {
416            let batch_status = anthropic::batches::retrieve_batch(
417                self.http_client.as_ref(),
418                ANTHROPIC_API_URL,
419                &self.api_key,
420                &batch_id,
421            )
422            .await
423            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
424
425            log::info!(
426                "Batch {} status: {}",
427                batch_id,
428                batch_status.processing_status
429            );
430
431            if batch_status.processing_status == "ended" {
432                let results = anthropic::batches::retrieve_batch_results(
433                    self.http_client.as_ref(),
434                    ANTHROPIC_API_URL,
435                    &self.api_key,
436                    &batch_id,
437                )
438                .await
439                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
440
441                let mut updates: Vec<(String, String)> = Vec::new();
442                let mut success_count = 0;
443                for result in results {
444                    let request_hash = result
445                        .custom_id
446                        .strip_prefix("req_hash_")
447                        .unwrap_or(&result.custom_id)
448                        .to_string();
449
450                    match result.result {
451                        anthropic::batches::BatchResult::Succeeded { message } => {
452                            let response_json = serde_json::to_string(&message)?;
453                            updates.push((response_json, request_hash));
454                            success_count += 1;
455                        }
456                        anthropic::batches::BatchResult::Errored { error } => {
457                            log::error!(
458                                "Batch request {} failed: {}: {}",
459                                request_hash,
460                                error.error.error_type,
461                                error.error.message
462                            );
463                            let error_json = serde_json::json!({
464                                "error": {
465                                    "type": error.error.error_type,
466                                    "message": error.error.message
467                                }
468                            })
469                            .to_string();
470                            updates.push((error_json, request_hash));
471                        }
472                        anthropic::batches::BatchResult::Canceled => {
473                            log::warn!("Batch request {} was canceled", request_hash);
474                            let error_json = serde_json::json!({
475                                "error": {
476                                    "type": "canceled",
477                                    "message": "Batch request was canceled"
478                                }
479                            })
480                            .to_string();
481                            updates.push((error_json, request_hash));
482                        }
483                        anthropic::batches::BatchResult::Expired => {
484                            log::warn!("Batch request {} expired", request_hash);
485                            let error_json = serde_json::json!({
486                                "error": {
487                                    "type": "expired",
488                                    "message": "Batch request expired"
489                                }
490                            })
491                            .to_string();
492                            updates.push((error_json, request_hash));
493                        }
494                    }
495                }
496
497                let connection = self.connection.lock().unwrap();
498                connection.with_savepoint("batch_download", || {
499                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
500                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
501                    for (response_json, request_hash) in &updates {
502                        exec((response_json.as_str(), request_hash.as_str()))?;
503                    }
504                    Ok(())
505                })?;
506                log::info!("Downloaded {} successful requests", success_count);
507            }
508        }
509
510        Ok(())
511    }
512
513    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
514        const BATCH_CHUNK_SIZE: i32 = 16_000;
515        const MAX_BATCH_SIZE_BYTES: usize = 200 * 1024 * 1024; // 200MB (buffer below 256MB limit)
516        let mut all_batch_ids = Vec::new();
517        let mut total_uploaded = 0;
518
519        let mut current_batch_rows = Vec::new();
520        let mut current_batch_size = 0usize;
521        let mut pending_hashes: HashSet<String> = HashSet::new();
522        loop {
523            let rows: Vec<(String, String)> = {
524                let connection = self.connection.lock().unwrap();
525                let q = sql!(
526                    SELECT request_hash, request FROM cache
527                    WHERE batch_id IS NULL AND response IS NULL
528                    LIMIT ?
529                );
530                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
531            };
532
533            if rows.is_empty() {
534                break;
535            }
536
537            // Split rows into sub-batches based on size
538            let mut batches_to_upload = Vec::new();
539            let mut new_rows_added = 0;
540
541            for row in rows {
542                let (hash, request_str) = row;
543
544                // Skip rows already added to current_batch_rows but not yet uploaded
545                if pending_hashes.contains(&hash) {
546                    continue;
547                }
548                let serializable_request: SerializableRequest = serde_json::from_str(&request_str)?;
549
550                let messages: Vec<Message> = serializable_request
551                    .messages
552                    .into_iter()
553                    .map(|msg| Message {
554                        role: match msg.role.as_str() {
555                            "user" => Role::User,
556                            "assistant" => Role::Assistant,
557                            _ => Role::User,
558                        },
559                        content: vec![RequestContent::Text {
560                            text: msg.content,
561                            cache_control: None,
562                        }],
563                    })
564                    .collect();
565
566                let params = AnthropicRequest {
567                    model: serializable_request.model,
568                    max_tokens: serializable_request.max_tokens,
569                    messages,
570                    tools: Vec::new(),
571                    thinking: None,
572                    tool_choice: None,
573                    system: None,
574                    metadata: None,
575                    output_config: None,
576                    stop_sequences: Vec::new(),
577                    temperature: None,
578                    top_k: None,
579                    top_p: None,
580                };
581
582                let custom_id = format!("req_hash_{}", hash);
583                let batch_request = anthropic::batches::BatchRequest { custom_id, params };
584
585                // Estimate the serialized size of this request
586                let estimated_size = serde_json::to_string(&batch_request)?.len();
587
588                // If adding this request would exceed the limit, start a new batch
589                if !current_batch_rows.is_empty()
590                    && current_batch_size + estimated_size > MAX_BATCH_SIZE_BYTES
591                {
592                    batches_to_upload.push((current_batch_rows, current_batch_size));
593                    current_batch_rows = Vec::new();
594                    current_batch_size = 0;
595                }
596
597                pending_hashes.insert(hash.clone());
598                current_batch_rows.push((hash, batch_request));
599                current_batch_size += estimated_size;
600                new_rows_added += 1;
601            }
602
603            // If no new rows were added this iteration, all pending requests are already
604            // in current_batch_rows, so we should break to avoid an infinite loop
605            if new_rows_added == 0 {
606                break;
607            }
608
609            // Only upload full batches, keep the partial batch for the next iteration
610            // Upload each sub-batch
611            for (batch_rows, batch_size) in batches_to_upload {
612                let request_hashes: Vec<String> =
613                    batch_rows.iter().map(|(hash, _)| hash.clone()).collect();
614
615                // Remove uploaded hashes from pending set
616                for hash in &request_hashes {
617                    pending_hashes.remove(hash);
618                }
619                let batch_requests: Vec<anthropic::batches::BatchRequest> =
620                    batch_rows.into_iter().map(|(_, req)| req).collect();
621
622                let batch_len = batch_requests.len();
623                log::info!(
624                    "Uploading batch with {} requests (~{:.2} MB)",
625                    batch_len,
626                    batch_size as f64 / (1024.0 * 1024.0)
627                );
628
629                let batch = anthropic::batches::create_batch(
630                    self.http_client.as_ref(),
631                    ANTHROPIC_API_URL,
632                    &self.api_key,
633                    anthropic::batches::CreateBatchRequest {
634                        requests: batch_requests,
635                    },
636                )
637                .await
638                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
639
640                {
641                    let connection = self.connection.lock().unwrap();
642                    connection.with_savepoint("batch_upload", || {
643                        let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
644                        let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
645                        for hash in &request_hashes {
646                            exec((batch.id.as_str(), hash.as_str()))?;
647                        }
648                        Ok(())
649                    })?;
650                }
651
652                total_uploaded += batch_len;
653                log::info!(
654                    "Uploaded batch {} with {} requests ({} total)",
655                    batch.id,
656                    batch_len,
657                    total_uploaded
658                );
659
660                all_batch_ids.push(batch.id);
661            }
662        }
663
664        // Upload any remaining partial batch at the end
665        if !current_batch_rows.is_empty() {
666            let request_hashes: Vec<String> = current_batch_rows
667                .iter()
668                .map(|(hash, _)| hash.clone())
669                .collect();
670            let batch_requests: Vec<anthropic::batches::BatchRequest> =
671                current_batch_rows.into_iter().map(|(_, req)| req).collect();
672
673            let batch_len = batch_requests.len();
674            log::info!(
675                "Uploading final batch with {} requests (~{:.2} MB)",
676                batch_len,
677                current_batch_size as f64 / (1024.0 * 1024.0)
678            );
679
680            let batch = anthropic::batches::create_batch(
681                self.http_client.as_ref(),
682                ANTHROPIC_API_URL,
683                &self.api_key,
684                anthropic::batches::CreateBatchRequest {
685                    requests: batch_requests,
686                },
687            )
688            .await
689            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
690
691            {
692                let connection = self.connection.lock().unwrap();
693                connection.with_savepoint("batch_upload", || {
694                    let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
695                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
696                    for hash in &request_hashes {
697                        exec((batch.id.as_str(), hash.as_str()))?;
698                    }
699                    Ok(())
700                })?;
701            }
702
703            total_uploaded += batch_len;
704            log::info!(
705                "Uploaded batch {} with {} requests ({} total)",
706                batch.id,
707                batch_len,
708                total_uploaded
709            );
710
711            all_batch_ids.push(batch.id);
712        }
713
714        if !all_batch_ids.is_empty() {
715            log::info!(
716                "Finished uploading {} batches with {} total requests",
717                all_batch_ids.len(),
718                total_uploaded
719            );
720        }
721
722        Ok(all_batch_ids)
723    }
724
725    fn request_hash(
726        model: &str,
727        max_tokens: u64,
728        messages: &[Message],
729        seed: Option<usize>,
730    ) -> String {
731        let mut hasher = std::hash::DefaultHasher::new();
732        model.hash(&mut hasher);
733        max_tokens.hash(&mut hasher);
734        for msg in messages {
735            message_content_to_string(&msg.content).hash(&mut hasher);
736        }
737        if let Some(seed) = seed {
738            seed.hash(&mut hasher);
739        }
740        let request_hash = hasher.finish();
741        format!("{request_hash:016x}")
742    }
743}
744
745fn message_content_to_string(content: &[RequestContent]) -> String {
746    content
747        .iter()
748        .filter_map(|c| match c {
749            RequestContent::Text { text, .. } => Some(text.clone()),
750            _ => None,
751        })
752        .collect::<Vec<String>>()
753        .join("\n")
754}
755
/// Front-end over the two client implementations, plus a panicking stub for
/// code paths that must never reach the LLM.
pub enum AnthropicClient {
    // No batching: requests go straight to the API.
    Plain(PlainLlmClient),
    // Cache-backed batching via the Batches API.
    Batch(BatchingLlmClient),
    // Panics if any generate/sync method is called.
    Dummy,
}
762
763impl AnthropicClient {
764    pub fn plain() -> Result<Self> {
765        Ok(Self::Plain(PlainLlmClient::new()?))
766    }
767
768    pub fn batch(cache_path: &Path) -> Result<Self> {
769        Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
770    }
771
772    #[allow(dead_code)]
773    pub fn dummy() -> Self {
774        Self::Dummy
775    }
776
777    pub async fn generate(
778        &self,
779        model: &str,
780        max_tokens: u64,
781        messages: Vec<Message>,
782        seed: Option<usize>,
783        cache_only: bool,
784    ) -> Result<Option<AnthropicResponse>> {
785        match self {
786            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
787                .generate(model, max_tokens, messages)
788                .await
789                .map(Some),
790            AnthropicClient::Batch(batching_llm_client) => {
791                batching_llm_client
792                    .generate(model, max_tokens, messages, seed, cache_only)
793                    .await
794            }
795            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
796        }
797    }
798
799    #[allow(dead_code)]
800    pub async fn generate_streaming<F>(
801        &self,
802        model: &str,
803        max_tokens: u64,
804        messages: Vec<Message>,
805        on_progress: F,
806    ) -> Result<Option<AnthropicResponse>>
807    where
808        F: FnMut(usize, &str),
809    {
810        match self {
811            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
812                .generate_streaming(model, max_tokens, messages, on_progress)
813                .await
814                .map(Some),
815            AnthropicClient::Batch(_) => {
816                anyhow::bail!("Streaming not supported with batching client")
817            }
818            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
819        }
820    }
821
822    pub async fn sync_batches(&self) -> Result<()> {
823        match self {
824            AnthropicClient::Plain(_) => Ok(()),
825            AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
826            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
827        }
828    }
829
830    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
831        match self {
832            AnthropicClient::Plain(_) => {
833                anyhow::bail!("Import batches is only supported with batching client")
834            }
835            AnthropicClient::Batch(batching_llm_client) => {
836                batching_llm_client.import_batches(batch_ids).await
837            }
838            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
839        }
840    }
841}