1use anthropic::{
2 ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
4 stream_completion,
5};
6use anyhow::Result;
7use futures::StreamExt as _;
8use http_client::HttpClient;
9use indoc::indoc;
10use reqwest_client::ReqwestClient;
11use sqlez::bindable::Bind;
12use sqlez::bindable::StaticColumnCount;
13use sqlez_macros::sql;
14use std::collections::HashSet;
15use std::hash::Hash;
16use std::hash::Hasher;
17use std::path::Path;
18use std::sync::{Arc, Mutex};
19
/// A direct (non-batching, non-caching) Anthropic API client.
///
/// Each call to [`PlainLlmClient::generate`] issues one HTTP request.
pub struct PlainLlmClient {
    // Shared HTTP transport; trait object so the transport can be swapped.
    pub http_client: Arc<dyn HttpClient>,
    // API key read from the `ANTHROPIC_API_KEY` environment variable in `new`.
    pub api_key: String,
}
24
25impl PlainLlmClient {
26 pub fn new() -> Result<Self> {
27 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
28 let api_key = std::env::var("ANTHROPIC_API_KEY")
29 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
30 Ok(Self {
31 http_client,
32 api_key,
33 })
34 }
35
36 pub async fn generate(
37 &self,
38 model: &str,
39 max_tokens: u64,
40 messages: Vec<Message>,
41 ) -> Result<AnthropicResponse> {
42 let request = AnthropicRequest {
43 model: model.to_string(),
44 max_tokens,
45 messages,
46 tools: Vec::new(),
47 thinking: None,
48 tool_choice: None,
49 system: None,
50 metadata: None,
51 output_config: None,
52 stop_sequences: Vec::new(),
53 speed: None,
54 temperature: None,
55 top_k: None,
56 top_p: None,
57 };
58
59 let response = non_streaming_completion(
60 self.http_client.as_ref(),
61 ANTHROPIC_API_URL,
62 &self.api_key,
63 request,
64 None,
65 )
66 .await
67 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
68
69 Ok(response)
70 }
71
72 pub async fn generate_streaming<F>(
73 &self,
74 model: &str,
75 max_tokens: u64,
76 messages: Vec<Message>,
77 mut on_progress: F,
78 ) -> Result<AnthropicResponse>
79 where
80 F: FnMut(usize, &str),
81 {
82 let request = AnthropicRequest {
83 model: model.to_string(),
84 max_tokens,
85 messages,
86 tools: Vec::new(),
87 thinking: None,
88 tool_choice: None,
89 system: None,
90 metadata: None,
91 output_config: None,
92 stop_sequences: Vec::new(),
93 speed: None,
94 temperature: None,
95 top_k: None,
96 top_p: None,
97 };
98
99 let mut stream = stream_completion(
100 self.http_client.as_ref(),
101 ANTHROPIC_API_URL,
102 &self.api_key,
103 request,
104 None,
105 )
106 .await
107 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
108
109 let mut response: Option<AnthropicResponse> = None;
110 let mut text_content = String::new();
111
112 while let Some(event_result) = stream.next().await {
113 let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
114
115 match event {
116 Event::MessageStart { message } => {
117 response = Some(message);
118 }
119 Event::ContentBlockDelta { delta, .. } => {
120 if let anthropic::ContentDelta::TextDelta { text } = delta {
121 text_content.push_str(&text);
122 on_progress(text_content.len(), &text_content);
123 }
124 }
125 _ => {}
126 }
127 }
128
129 let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
130
131 if response.content.is_empty() && !text_content.is_empty() {
132 response
133 .content
134 .push(ResponseContent::Text { text: text_content });
135 }
136
137 Ok(response)
138 }
139}
140
/// An Anthropic client that defers requests to the batches API, caching
/// requests and responses in a local SQLite database.
pub struct BatchingLlmClient {
    // SQLite cache; guarded by a Mutex because sqlez connections are not Sync.
    // NOTE(review): assumed rationale — confirm against sqlez's thread-safety docs.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}
146
/// One row of the `cache` table: a request keyed by its hash, plus the
/// serialized request/response JSON and the batch it was submitted in.
struct CacheRow {
    // Primary key; see `BatchingLlmClient::request_hash`.
    request_hash: String,
    // Serialized `SerializableRequest` JSON, if known.
    request: Option<String>,
    // Serialized `AnthropicResponse` JSON (or an error payload) once available.
    response: Option<String>,
    // Anthropic batch id once the request has been uploaded.
    batch_id: Option<String>,
}
153
impl StaticColumnCount for CacheRow {
    /// `CacheRow` binds exactly four columns:
    /// request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
159
160impl Bind for CacheRow {
161 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
162 let next_index = statement.bind(&self.request_hash, start_index)?;
163 let next_index = statement.bind(&self.request, next_index)?;
164 let next_index = statement.bind(&self.response, next_index)?;
165 let next_index = statement.bind(&self.batch_id, next_index)?;
166 Ok(next_index)
167 }
168}
169
/// JSON-storable form of a request, persisted in the cache's `request` column
/// so pending requests survive process restarts.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
176
/// A message flattened to a role string ("user"/"assistant") and plain text;
/// non-text content is dropped by `message_content_to_string`.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
182
impl BatchingLlmClient {
    /// Opens (or creates) the SQLite cache at `cache_path` and ensures the
    /// `cache` table exists.
    ///
    /// # Errors
    /// Fails if `ANTHROPIC_API_KEY` is unset, if `cache_path` is not valid
    /// UTF-8, or if the schema statement cannot be prepared or executed.
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;

        // `new` already returns Result, so surface a non-UTF-8 path as an
        // error instead of panicking via unwrap().
        let cache_path = cache_path
            .to_str()
            .ok_or_else(|| anyhow::anyhow!("cache path is not valid UTF-8: {:?}", cache_path))?;
        let connection = sqlez::connection::Connection::open_file(cache_path);
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
            "},
        )?;
        statement.exec()?;
        // Drop the statement before moving the connection into the Mutex.
        drop(statement);

        Ok(Self {
            connection: Mutex::new(connection),
            http_client,
            api_key,
        })
    }
210
    /// Returns the cached response for this request, if one has been stored.
    ///
    /// Rows whose `response` column fails to deserialize as an
    /// `AnthropicResponse` — including stored error payloads written by the
    /// batch download paths — are treated as cache misses (`Ok(None)`).
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> Result<Option<AnthropicResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            // Deliberately swallow parse failures: error JSON → cache miss.
            .and_then(|text| serde_json::from_str(&text).ok()))
    }
228
229 pub fn mark_for_batch(
230 &self,
231 model: &str,
232 max_tokens: u64,
233 messages: &[Message],
234 seed: Option<usize>,
235 ) -> Result<()> {
236 let request_hash = Self::request_hash(model, max_tokens, messages, seed);
237
238 let serializable_messages: Vec<SerializableMessage> = messages
239 .iter()
240 .map(|msg| SerializableMessage {
241 role: match msg.role {
242 Role::User => "user".to_string(),
243 Role::Assistant => "assistant".to_string(),
244 },
245 content: message_content_to_string(&msg.content),
246 })
247 .collect();
248
249 let serializable_request = SerializableRequest {
250 model: model.to_string(),
251 max_tokens,
252 messages: serializable_messages,
253 };
254
255 let request = Some(serde_json::to_string(&serializable_request)?);
256 let cache_row = CacheRow {
257 request_hash,
258 request,
259 response: None,
260 batch_id: None,
261 };
262 let connection = self.connection.lock().unwrap();
263 connection.exec_bound::<CacheRow>(sql!(
264 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
265 cache_row,
266 )
267 }
268
269 async fn generate(
270 &self,
271 model: &str,
272 max_tokens: u64,
273 messages: Vec<Message>,
274 seed: Option<usize>,
275 cache_only: bool,
276 ) -> Result<Option<AnthropicResponse>> {
277 let response = self.lookup(model, max_tokens, &messages, seed)?;
278 if let Some(response) = response {
279 return Ok(Some(response));
280 }
281
282 if !cache_only {
283 self.mark_for_batch(model, max_tokens, &messages, seed)?;
284 }
285
286 Ok(None)
287 }
288
289 /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
290 async fn sync_batches(&self) -> Result<()> {
291 let _batch_ids = self.upload_pending_requests().await?;
292 self.download_finished_batches().await
293 }
294
295 pub fn pending_batch_count(&self) -> Result<usize> {
296 let connection = self.connection.lock().unwrap();
297 let counts: Vec<i32> = connection.select(
298 sql!(SELECT COUNT(*) FROM cache WHERE batch_id IS NOT NULL AND response IS NULL),
299 )?()?;
300 Ok(counts.into_iter().next().unwrap_or(0) as usize)
301 }
302
    /// Import batch results from external batch IDs (useful for recovering after database loss)
    ///
    /// For each id: fetches the batch's status, skips it unless processing has
    /// "ended", then writes every result into the local cache keyed by the
    /// request hash embedded in the result's `custom_id`.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing batch {}", batch_id);

            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Only finished batches have downloadable results.
            if batch_status.processing_status != "ended" {
                log::warn!(
                    "Batch {} is not finished (status: {}), skipping",
                    batch_id,
                    batch_status.processing_status
                );
                continue;
            }

            let results = anthropic::batches::retrieve_batch_results(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
            })?;

            // (request_hash, response_json, batch_id) triples to persist below.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // custom_id was written as "req_hash_<hash>" at upload time;
                // fall back to the raw custom_id if the prefix is missing.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                match result.result {
                    anthropic::batches::BatchResult::Succeeded { message } => {
                        let response_json = serde_json::to_string(&message)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    }
                    anthropic::batches::BatchResult::Errored { error } => {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.error.error_type,
                            error.error.message
                        );
                        // Persist the error payload so the row is no longer
                        // counted as pending.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.error.error_type,
                                "message": error.error.message
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Canceled => {
                        // NOTE(review): unlike download_finished_batches, canceled
                        // results are logged but not persisted here, leaving the
                        // row (if any) eligible for re-upload — confirm intended.
                        log::warn!("Batch request {} was canceled", request_hash);
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Expired => {
                        // NOTE(review): same asymmetry as Canceled above.
                        log::warn!("Batch request {} expired", request_hash);
                        error_count += 1;
                    }
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
                // (the subselect carries forward any stored request body).
                let q = sql!(
                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        // Bound twice: once as the new key, once for the subselect.
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
417
    /// Downloads results for every batch that has rows still awaiting a
    /// response, and stores each result (success or error payload) in the
    /// cache so the row stops counting as pending.
    async fn download_finished_batches(&self) -> Result<()> {
        // Batch ids with at least one row still lacking a response.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Unfinished batches are simply skipped until a later sync.
            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                // (response_json, request_hash) pairs for the UPDATE below —
                // note the order matches the statement's placeholders.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;
                for result in results {
                    // custom_id was written as "req_hash_<hash>" at upload time.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        }
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!(
                                "Batch request {} failed: {}: {}",
                                request_hash,
                                error.error.error_type,
                                error.error.message
                            );
                            // Store the error payload so the row is not
                            // re-downloaded (lookup treats it as a miss).
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": error.error.error_type,
                                    "message": error.error.message
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "canceled",
                                    "message": "Batch request was canceled"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "expired",
                                    "message": "Batch request expired"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    }
                }

                // Apply all updates for this batch atomically.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
522
523 async fn upload_pending_requests(&self) -> Result<Vec<String>> {
524 const BATCH_CHUNK_SIZE: i32 = 16_000;
525 const MAX_BATCH_SIZE_BYTES: usize = 100 * 1024 * 1024;
526
527 let mut all_batch_ids = Vec::new();
528 let mut total_uploaded = 0;
529
530 let mut current_batch_rows = Vec::new();
531 let mut current_batch_size = 0usize;
532 let mut pending_hashes: HashSet<String> = HashSet::new();
533 loop {
534 let rows: Vec<(String, String)> = {
535 let connection = self.connection.lock().unwrap();
536 let q = sql!(
537 SELECT request_hash, request FROM cache
538 WHERE batch_id IS NULL AND response IS NULL
539 LIMIT ?
540 );
541 connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
542 };
543
544 if rows.is_empty() {
545 break;
546 }
547
548 // Split rows into sub-batches based on size
549 let mut batches_to_upload = Vec::new();
550 let mut new_rows_added = 0;
551
552 for row in rows {
553 let (hash, request_str) = row;
554
555 // Skip rows already added to current_batch_rows but not yet uploaded
556 if pending_hashes.contains(&hash) {
557 continue;
558 }
559 let serializable_request: SerializableRequest = serde_json::from_str(&request_str)?;
560
561 let messages: Vec<Message> = serializable_request
562 .messages
563 .into_iter()
564 .map(|msg| Message {
565 role: match msg.role.as_str() {
566 "user" => Role::User,
567 "assistant" => Role::Assistant,
568 _ => Role::User,
569 },
570 content: vec![RequestContent::Text {
571 text: msg.content,
572 cache_control: None,
573 }],
574 })
575 .collect();
576
577 let params = AnthropicRequest {
578 model: serializable_request.model,
579 max_tokens: serializable_request.max_tokens,
580 messages,
581 tools: Vec::new(),
582 thinking: None,
583 tool_choice: None,
584 system: None,
585 metadata: None,
586 output_config: None,
587 stop_sequences: Vec::new(),
588 temperature: None,
589 top_k: None,
590 top_p: None,
591 speed: None,
592 };
593
594 let custom_id = format!("req_hash_{}", hash);
595 let batch_request = anthropic::batches::BatchRequest { custom_id, params };
596
597 // Estimate the serialized size of this request
598 let estimated_size = serde_json::to_string(&batch_request)?.len();
599
600 // If adding this request would exceed the limit, start a new batch
601 if !current_batch_rows.is_empty()
602 && current_batch_size + estimated_size > MAX_BATCH_SIZE_BYTES
603 {
604 batches_to_upload.push((current_batch_rows, current_batch_size));
605 current_batch_rows = Vec::new();
606 current_batch_size = 0;
607 }
608
609 pending_hashes.insert(hash.clone());
610 current_batch_rows.push((hash, batch_request));
611 current_batch_size += estimated_size;
612 new_rows_added += 1;
613 }
614
615 // If no new rows were added this iteration, all pending requests are already
616 // in current_batch_rows, so we should break to avoid an infinite loop
617 if new_rows_added == 0 {
618 break;
619 }
620
621 // Only upload full batches, keep the partial batch for the next iteration
622 // Upload each sub-batch
623 for (batch_rows, batch_size) in batches_to_upload {
624 let request_hashes: Vec<String> =
625 batch_rows.iter().map(|(hash, _)| hash.clone()).collect();
626
627 // Remove uploaded hashes from pending set
628 for hash in &request_hashes {
629 pending_hashes.remove(hash);
630 }
631 let batch_requests: Vec<anthropic::batches::BatchRequest> =
632 batch_rows.into_iter().map(|(_, req)| req).collect();
633
634 let batch_len = batch_requests.len();
635 log::info!(
636 "Uploading batch with {} requests (~{:.2} MB)",
637 batch_len,
638 batch_size as f64 / (1024.0 * 1024.0)
639 );
640
641 let batch = anthropic::batches::create_batch(
642 self.http_client.as_ref(),
643 ANTHROPIC_API_URL,
644 &self.api_key,
645 anthropic::batches::CreateBatchRequest {
646 requests: batch_requests,
647 },
648 )
649 .await
650 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
651
652 {
653 let connection = self.connection.lock().unwrap();
654 connection.with_savepoint("batch_upload", || {
655 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
656 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
657 for hash in &request_hashes {
658 exec((batch.id.as_str(), hash.as_str()))?;
659 }
660 Ok(())
661 })?;
662 }
663
664 total_uploaded += batch_len;
665 log::info!(
666 "Uploaded batch {} with {} requests ({} total)",
667 batch.id,
668 batch_len,
669 total_uploaded
670 );
671
672 all_batch_ids.push(batch.id);
673 }
674 }
675
676 // Upload any remaining partial batch at the end
677 if !current_batch_rows.is_empty() {
678 let request_hashes: Vec<String> = current_batch_rows
679 .iter()
680 .map(|(hash, _)| hash.clone())
681 .collect();
682 let batch_requests: Vec<anthropic::batches::BatchRequest> =
683 current_batch_rows.into_iter().map(|(_, req)| req).collect();
684
685 let batch_len = batch_requests.len();
686 log::info!(
687 "Uploading final batch with {} requests (~{:.2} MB)",
688 batch_len,
689 current_batch_size as f64 / (1024.0 * 1024.0)
690 );
691
692 let batch = anthropic::batches::create_batch(
693 self.http_client.as_ref(),
694 ANTHROPIC_API_URL,
695 &self.api_key,
696 anthropic::batches::CreateBatchRequest {
697 requests: batch_requests,
698 },
699 )
700 .await
701 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
702
703 {
704 let connection = self.connection.lock().unwrap();
705 connection.with_savepoint("batch_upload", || {
706 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
707 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
708 for hash in &request_hashes {
709 exec((batch.id.as_str(), hash.as_str()))?;
710 }
711 Ok(())
712 })?;
713 }
714
715 total_uploaded += batch_len;
716 log::info!(
717 "Uploaded batch {} with {} requests ({} total)",
718 batch.id,
719 batch_len,
720 total_uploaded
721 );
722
723 all_batch_ids.push(batch.id);
724 }
725
726 if !all_batch_ids.is_empty() {
727 log::info!(
728 "Finished uploading {} batches with {} total requests",
729 all_batch_ids.len(),
730 total_uploaded
731 );
732 }
733
734 Ok(all_batch_ids)
735 }
736
    /// Computes the cache key for a request: a 16-hex-digit hash of the model,
    /// max_tokens, each message's text content, and the seed (if any).
    ///
    /// `DefaultHasher::new()` uses fixed keys, so the key is stable across
    /// runs of the same build; it is NOT guaranteed stable across Rust
    /// versions, which would invalidate an existing cache after a toolchain
    /// upgrade.
    ///
    /// NOTE(review): message roles are not hashed — two conversations with
    /// identical text but different role assignments collide. Confirm whether
    /// that is acceptable before relying on the cache across role variations.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(&msg.content).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        // Zero-padded lowercase hex, always 16 characters.
        format!("{request_hash:016x}")
    }
}
756
757fn message_content_to_string(content: &[RequestContent]) -> String {
758 content
759 .iter()
760 .filter_map(|c| match c {
761 RequestContent::Text { text, .. } => Some(text.clone()),
762 _ => None,
763 })
764 .collect::<Vec<String>>()
765 .join("\n")
766}
767
/// Facade over the two client implementations, plus a panicking placeholder.
pub enum AnthropicClient {
    // No batching — each request is sent immediately.
    Plain(PlainLlmClient),
    // Requests are cached locally and submitted via the batches API.
    Batch(BatchingLlmClient),
    // Panics on use; for code paths that must never reach the LLM.
    Dummy,
}
774
impl AnthropicClient {
    /// Creates a non-batching client. Fails if `ANTHROPIC_API_KEY` is unset.
    pub fn plain() -> Result<Self> {
        Ok(Self::Plain(PlainLlmClient::new()?))
    }

    /// Creates a batching client backed by a SQLite cache at `cache_path`.
    pub fn batch(cache_path: &Path) -> Result<Self> {
        Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
    }

    /// Creates a placeholder client that panics if any method is called.
    #[allow(dead_code)]
    pub fn dummy() -> Self {
        Self::Dummy
    }

    /// Generates a completion.
    ///
    /// Plain: always issues a request and returns `Some`. Batch: returns the
    /// cached response, or `None` after (unless `cache_only`) enqueuing the
    /// request for the next batch upload. `seed` and `cache_only` only affect
    /// the batch variant.
    ///
    /// # Panics
    /// Panics on the `Dummy` variant.
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
        seed: Option<usize>,
        cache_only: bool,
    ) -> Result<Option<AnthropicResponse>> {
        match self {
            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
                .generate(model, max_tokens, messages)
                .await
                .map(Some),
            AnthropicClient::Batch(batching_llm_client) => {
                batching_llm_client
                    .generate(model, max_tokens, messages, seed, cache_only)
                    .await
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Streams a completion; only supported by the plain variant.
    ///
    /// # Errors
    /// Fails on the batch variant, which cannot stream.
    ///
    /// # Panics
    /// Panics on the `Dummy` variant.
    #[allow(dead_code)]
    pub async fn generate_streaming<F>(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
        on_progress: F,
    ) -> Result<Option<AnthropicResponse>>
    where
        F: FnMut(usize, &str),
    {
        match self {
            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
                .generate_streaming(model, max_tokens, messages, on_progress)
                .await
                .map(Some),
            AnthropicClient::Batch(_) => {
                anyhow::bail!("Streaming not supported with batching client")
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Uploads pending requests and downloads finished batches.
    /// No-op for the plain variant; panics on `Dummy`.
    pub async fn sync_batches(&self) -> Result<()> {
        match self {
            AnthropicClient::Plain(_) => Ok(()),
            AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Number of uploaded requests still awaiting a response.
    /// Always 0 for the plain variant; panics on `Dummy`.
    pub fn pending_batch_count(&self) -> Result<usize> {
        match self {
            AnthropicClient::Plain(_) => Ok(0),
            AnthropicClient::Batch(batching_llm_client) => {
                batching_llm_client.pending_batch_count()
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

    /// Imports results from externally known batch ids into the local cache.
    ///
    /// # Errors
    /// Fails on the plain variant, which has no cache.
    ///
    /// # Panics
    /// Panics on the `Dummy` variant.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        match self {
            AnthropicClient::Plain(_) => {
                anyhow::bail!("Import batches is only supported with batching client")
            }
            AnthropicClient::Batch(batching_llm_client) => {
                batching_llm_client.import_batches(batch_ids).await
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }
}