1use anthropic::{
2 ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
4 stream_completion,
5};
6use anyhow::Result;
7use futures::StreamExt as _;
8use http_client::HttpClient;
9use indoc::indoc;
10use reqwest_client::ReqwestClient;
11use sqlez::bindable::Bind;
12use sqlez::bindable::StaticColumnCount;
13use sqlez_macros::sql;
14use std::hash::Hash;
15use std::hash::Hasher;
16use std::path::Path;
17use std::sync::{Arc, Mutex};
18
/// A direct, non-batching Anthropic API client: each `generate` call issues
/// one HTTP request immediately.
pub struct PlainLlmClient {
    // Shared HTTP transport used for every API call.
    pub http_client: Arc<dyn HttpClient>,
    // Anthropic API key; read from the environment in `new`.
    pub api_key: String,
}
23
24impl PlainLlmClient {
25 pub fn new() -> Result<Self> {
26 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
27 let api_key = std::env::var("ANTHROPIC_API_KEY")
28 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
29 Ok(Self {
30 http_client,
31 api_key,
32 })
33 }
34
35 pub async fn generate(
36 &self,
37 model: &str,
38 max_tokens: u64,
39 messages: Vec<Message>,
40 ) -> Result<AnthropicResponse> {
41 let request = AnthropicRequest {
42 model: model.to_string(),
43 max_tokens,
44 messages,
45 tools: Vec::new(),
46 thinking: None,
47 tool_choice: None,
48 system: None,
49 metadata: None,
50 output_config: None,
51 stop_sequences: Vec::new(),
52 temperature: None,
53 top_k: None,
54 top_p: None,
55 };
56
57 let response = non_streaming_completion(
58 self.http_client.as_ref(),
59 ANTHROPIC_API_URL,
60 &self.api_key,
61 request,
62 None,
63 )
64 .await
65 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
66
67 Ok(response)
68 }
69
70 pub async fn generate_streaming<F>(
71 &self,
72 model: &str,
73 max_tokens: u64,
74 messages: Vec<Message>,
75 mut on_progress: F,
76 ) -> Result<AnthropicResponse>
77 where
78 F: FnMut(usize, &str),
79 {
80 let request = AnthropicRequest {
81 model: model.to_string(),
82 max_tokens,
83 messages,
84 tools: Vec::new(),
85 thinking: None,
86 tool_choice: None,
87 system: None,
88 metadata: None,
89 output_config: None,
90 stop_sequences: Vec::new(),
91 temperature: None,
92 top_k: None,
93 top_p: None,
94 };
95
96 let mut stream = stream_completion(
97 self.http_client.as_ref(),
98 ANTHROPIC_API_URL,
99 &self.api_key,
100 request,
101 None,
102 )
103 .await
104 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
105
106 let mut response: Option<AnthropicResponse> = None;
107 let mut text_content = String::new();
108
109 while let Some(event_result) = stream.next().await {
110 let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
111
112 match event {
113 Event::MessageStart { message } => {
114 response = Some(message);
115 }
116 Event::ContentBlockDelta { delta, .. } => {
117 if let anthropic::ContentDelta::TextDelta { text } = delta {
118 text_content.push_str(&text);
119 on_progress(text_content.len(), &text_content);
120 }
121 }
122 _ => {}
123 }
124 }
125
126 let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
127
128 if response.content.is_empty() && !text_content.is_empty() {
129 response
130 .content
131 .push(ResponseContent::Text { text: text_content });
132 }
133
134 Ok(response)
135 }
136}
137
/// An Anthropic client that never calls the completion API synchronously:
/// requests are recorded in a SQLite cache, uploaded through the Batches
/// API, and served from the cache once batch results are downloaded.
pub struct BatchingLlmClient {
    // SQLite request/response cache; the Mutex serializes all DB access.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key; read from the environment in `new`.
    api_key: String,
}
143
/// One row of the `cache` table (schema created in `BatchingLlmClient::new`).
struct CacheRow {
    // Hex-encoded request hash; the table's primary key.
    request_hash: String,
    // Serialized `SerializableRequest` JSON, when known.
    request: Option<String>,
    // Serialized response (or error payload) JSON, once downloaded.
    response: Option<String>,
    // Id of the uploaded batch containing this request, once uploaded.
    batch_id: Option<String>,
}
150
impl StaticColumnCount for CacheRow {
    // `CacheRow` binds exactly four columns:
    // request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
156
157impl Bind for CacheRow {
158 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
159 let next_index = statement.bind(&self.request_hash, start_index)?;
160 let next_index = statement.bind(&self.request, next_index)?;
161 let next_index = statement.bind(&self.response, next_index)?;
162 let next_index = statement.bind(&self.batch_id, next_index)?;
163 Ok(next_index)
164 }
165}
166
/// JSON-serializable snapshot of a request, stored in the cache's `request`
/// column so full API requests can be reconstructed at upload time.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
173
/// Flattened message: role as a string ("user" / "assistant") and all text
/// content joined into a single string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
179
180impl BatchingLlmClient {
181 fn new(cache_path: &Path) -> Result<Self> {
182 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
183 let api_key = std::env::var("ANTHROPIC_API_KEY")
184 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
185
186 let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
187 let mut statement = sqlez::statement::Statement::prepare(
188 &connection,
189 indoc! {"
190 CREATE TABLE IF NOT EXISTS cache (
191 request_hash TEXT PRIMARY KEY,
192 request TEXT,
193 response TEXT,
194 batch_id TEXT
195 );
196 "},
197 )?;
198 statement.exec()?;
199 drop(statement);
200
201 Ok(Self {
202 connection: Mutex::new(connection),
203 http_client,
204 api_key,
205 })
206 }
207
208 pub fn lookup(
209 &self,
210 model: &str,
211 max_tokens: u64,
212 messages: &[Message],
213 seed: Option<usize>,
214 ) -> Result<Option<AnthropicResponse>> {
215 let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
216 let connection = self.connection.lock().unwrap();
217 let response: Vec<String> = connection.select_bound(
218 &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
219 )?(request_hash_str.as_str())?;
220 Ok(response
221 .into_iter()
222 .next()
223 .and_then(|text| serde_json::from_str(&text).ok()))
224 }
225
226 pub fn mark_for_batch(
227 &self,
228 model: &str,
229 max_tokens: u64,
230 messages: &[Message],
231 seed: Option<usize>,
232 ) -> Result<()> {
233 let request_hash = Self::request_hash(model, max_tokens, messages, seed);
234
235 let serializable_messages: Vec<SerializableMessage> = messages
236 .iter()
237 .map(|msg| SerializableMessage {
238 role: match msg.role {
239 Role::User => "user".to_string(),
240 Role::Assistant => "assistant".to_string(),
241 },
242 content: message_content_to_string(&msg.content),
243 })
244 .collect();
245
246 let serializable_request = SerializableRequest {
247 model: model.to_string(),
248 max_tokens,
249 messages: serializable_messages,
250 };
251
252 let request = Some(serde_json::to_string(&serializable_request)?);
253 let cache_row = CacheRow {
254 request_hash,
255 request,
256 response: None,
257 batch_id: None,
258 };
259 let connection = self.connection.lock().unwrap();
260 connection.exec_bound::<CacheRow>(sql!(
261 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
262 cache_row,
263 )
264 }
265
266 async fn generate(
267 &self,
268 model: &str,
269 max_tokens: u64,
270 messages: Vec<Message>,
271 seed: Option<usize>,
272 cache_only: bool,
273 ) -> Result<Option<AnthropicResponse>> {
274 let response = self.lookup(model, max_tokens, &messages, seed)?;
275 if let Some(response) = response {
276 return Ok(Some(response));
277 }
278
279 if !cache_only {
280 self.mark_for_batch(model, max_tokens, &messages, seed)?;
281 }
282
283 Ok(None)
284 }
285
286 /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
287 async fn sync_batches(&self) -> Result<()> {
288 let _batch_ids = self.upload_pending_requests().await?;
289 self.download_finished_batches().await
290 }
291
    /// Import batch results from external batch IDs (useful for recovering after database loss)
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing batch {}", batch_id);

            // Check processing status before fetching results.
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Only batches whose processing has "ended" can be imported.
            if batch_status.processing_status != "ended" {
                log::warn!(
                    "Batch {} is not finished (status: {}), skipping",
                    batch_id,
                    batch_status.processing_status
                );
                continue;
            }

            let results = anthropic::batches::retrieve_batch_results(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
            })?;

            // (request_hash, response JSON, batch id) triples to write back.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // Requests are uploaded with custom_id "req_hash_<hash>";
                // fall back to the raw custom_id for externally-created batches.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                match result.result {
                    anthropic::batches::BatchResult::Succeeded { message } => {
                        let response_json = serde_json::to_string(&message)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    }
                    anthropic::batches::BatchResult::Errored { error } => {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.error.error_type,
                            error.error.message
                        );
                        // Persist the error payload so the failure is visible
                        // in the cache.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.error.error_type,
                                "message": error.error.message
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Canceled => {
                        // NOTE(review): unlike download_finished_batches, a
                        // canceled result is only counted here, not written to
                        // the cache — confirm the asymmetry is intentional.
                        log::warn!("Batch request {} was canceled", request_hash);
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Expired => {
                        // NOTE(review): expired results are likewise counted
                        // but not persisted.
                        log::warn!("Batch request {} expired", request_hash);
                        error_count += 1;
                    }
                }
            }

            // Apply all updates for this batch atomically in one savepoint.
            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
                // The inner SELECT preserves any request JSON already stored
                // under this hash (NULL when the hash was not seen before).
                let q = sql!(
                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
406
    /// Polls every batch that still has rows awaiting a response; for each
    /// batch whose processing has ended, downloads its results and stores the
    /// outcome (success or error payload) in the cache.
    async fn download_finished_batches(&self) -> Result<()> {
        // Collect candidate batch ids up front so the DB lock is not held
        // across the network calls below.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Results are only retrievable once processing has "ended";
            // unfinished batches are retried on the next sync.
            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                // (response JSON, request_hash) pairs to write back.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;
                for result in results {
                    // custom_id was uploaded as "req_hash_<hash>".
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    // Non-success outcomes are stored as {"error": ...}
                    // payloads so the row is not re-uploaded; presumably
                    // `lookup` fails to parse them as responses and treats
                    // them as misses — confirm against AnthropicResponse's
                    // serde impl.
                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        }
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!(
                                "Batch request {} failed: {}: {}",
                                request_hash,
                                error.error.error_type,
                                error.error.message
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": error.error.error_type,
                                    "message": error.error.message
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "canceled",
                                    "message": "Batch request was canceled"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "expired",
                                    "message": "Batch request expired"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    }
                }

                // Write all results for this batch atomically.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
511
512 async fn upload_pending_requests(&self) -> Result<Vec<String>> {
513 const BATCH_CHUNK_SIZE: i32 = 16_000;
514 let mut all_batch_ids = Vec::new();
515 let mut total_uploaded = 0;
516
517 loop {
518 let rows: Vec<(String, String)> = {
519 let connection = self.connection.lock().unwrap();
520 let q = sql!(
521 SELECT request_hash, request FROM cache
522 WHERE batch_id IS NULL AND response IS NULL
523 LIMIT ?
524 );
525 connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
526 };
527
528 if rows.is_empty() {
529 break;
530 }
531
532 let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
533
534 let batch_requests = rows
535 .iter()
536 .map(|(hash, request_str)| {
537 let serializable_request: SerializableRequest =
538 serde_json::from_str(&request_str).unwrap();
539
540 let messages: Vec<Message> = serializable_request
541 .messages
542 .into_iter()
543 .map(|msg| Message {
544 role: match msg.role.as_str() {
545 "user" => Role::User,
546 "assistant" => Role::Assistant,
547 _ => Role::User,
548 },
549 content: vec![RequestContent::Text {
550 text: msg.content,
551 cache_control: None,
552 }],
553 })
554 .collect();
555
556 let params = AnthropicRequest {
557 model: serializable_request.model,
558 max_tokens: serializable_request.max_tokens,
559 messages,
560 tools: Vec::new(),
561 thinking: None,
562 tool_choice: None,
563 system: None,
564 metadata: None,
565 output_config: None,
566 stop_sequences: Vec::new(),
567 temperature: None,
568 top_k: None,
569 top_p: None,
570 };
571
572 let custom_id = format!("req_hash_{}", hash);
573 anthropic::batches::BatchRequest { custom_id, params }
574 })
575 .collect::<Vec<_>>();
576
577 let batch_len = batch_requests.len();
578 let batch = anthropic::batches::create_batch(
579 self.http_client.as_ref(),
580 ANTHROPIC_API_URL,
581 &self.api_key,
582 anthropic::batches::CreateBatchRequest {
583 requests: batch_requests,
584 },
585 )
586 .await
587 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
588
589 {
590 let connection = self.connection.lock().unwrap();
591 connection.with_savepoint("batch_upload", || {
592 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
593 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
594 for hash in &request_hashes {
595 exec((batch.id.as_str(), hash.as_str()))?;
596 }
597 Ok(())
598 })?;
599 }
600
601 total_uploaded += batch_len;
602 log::info!(
603 "Uploaded batch {} with {} requests ({} total)",
604 batch.id,
605 batch_len,
606 total_uploaded
607 );
608
609 all_batch_ids.push(batch.id);
610 }
611
612 if !all_batch_ids.is_empty() {
613 log::info!(
614 "Finished uploading {} batches with {} total requests",
615 all_batch_ids.len(),
616 total_uploaded
617 );
618 }
619
620 Ok(all_batch_ids)
621 }
622
    /// Computes the cache key for a request: a 16-hex-digit string over
    /// model, max_tokens, each message's text content, and the optional seed.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable
    /// across Rust releases, so keys in a persisted cache may stop matching
    /// after a toolchain upgrade — confirm this is acceptable.
    /// NOTE(review): message roles are excluded from the hash, so
    /// conversations differing only in role assignment collide.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            // Only the concatenated text content participates in the hash.
            message_content_to_string(&msg.content).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        // Zero-padded lowercase hex: always exactly 16 characters.
        format!("{request_hash:016x}")
    }
641}
642
643fn message_content_to_string(content: &[RequestContent]) -> String {
644 content
645 .iter()
646 .filter_map(|c| match c {
647 RequestContent::Text { text, .. } => Some(text.clone()),
648 _ => None,
649 })
650 .collect::<Vec<String>>()
651 .join("\n")
652}
653
/// Front-end over the two client implementations plus a test placeholder.
pub enum AnthropicClient {
    // No batching
    Plain(PlainLlmClient),
    // SQLite-cache-backed batching client.
    Batch(BatchingLlmClient),
    // Placeholder; panics if any API method is invoked on it.
    Dummy,
}
660
661impl AnthropicClient {
662 pub fn plain() -> Result<Self> {
663 Ok(Self::Plain(PlainLlmClient::new()?))
664 }
665
666 pub fn batch(cache_path: &Path) -> Result<Self> {
667 Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
668 }
669
670 #[allow(dead_code)]
671 pub fn dummy() -> Self {
672 Self::Dummy
673 }
674
675 pub async fn generate(
676 &self,
677 model: &str,
678 max_tokens: u64,
679 messages: Vec<Message>,
680 seed: Option<usize>,
681 cache_only: bool,
682 ) -> Result<Option<AnthropicResponse>> {
683 match self {
684 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
685 .generate(model, max_tokens, messages)
686 .await
687 .map(Some),
688 AnthropicClient::Batch(batching_llm_client) => {
689 batching_llm_client
690 .generate(model, max_tokens, messages, seed, cache_only)
691 .await
692 }
693 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
694 }
695 }
696
697 #[allow(dead_code)]
698 pub async fn generate_streaming<F>(
699 &self,
700 model: &str,
701 max_tokens: u64,
702 messages: Vec<Message>,
703 on_progress: F,
704 ) -> Result<Option<AnthropicResponse>>
705 where
706 F: FnMut(usize, &str),
707 {
708 match self {
709 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
710 .generate_streaming(model, max_tokens, messages, on_progress)
711 .await
712 .map(Some),
713 AnthropicClient::Batch(_) => {
714 anyhow::bail!("Streaming not supported with batching client")
715 }
716 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
717 }
718 }
719
720 pub async fn sync_batches(&self) -> Result<()> {
721 match self {
722 AnthropicClient::Plain(_) => Ok(()),
723 AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
724 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
725 }
726 }
727
728 pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
729 match self {
730 AnthropicClient::Plain(_) => {
731 anyhow::bail!("Import batches is only supported with batching client")
732 }
733 AnthropicClient::Batch(batching_llm_client) => {
734 batching_llm_client.import_batches(batch_ids).await
735 }
736 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
737 }
738 }
739}