1use anthropic::{
2 ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
4 stream_completion,
5};
6use anyhow::Result;
7use futures::StreamExt as _;
8use http_client::HttpClient;
9use indoc::indoc;
10use reqwest_client::ReqwestClient;
11use sqlez::bindable::Bind;
12use sqlez::bindable::StaticColumnCount;
13use sqlez_macros::sql;
14use std::hash::Hash;
15use std::hash::Hasher;
16use std::path::Path;
17use std::sync::{Arc, Mutex};
18
/// A direct Anthropic API client: every call goes straight to the API,
/// with no caching or batching.
pub struct PlainLlmClient {
    /// HTTP transport used for all API calls.
    pub http_client: Arc<dyn HttpClient>,
    /// API key, read from the `ANTHROPIC_API_KEY` environment variable in `new`.
    pub api_key: String,
}
23
24impl PlainLlmClient {
25 pub fn new() -> Result<Self> {
26 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
27 let api_key = std::env::var("ANTHROPIC_API_KEY")
28 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
29 Ok(Self {
30 http_client,
31 api_key,
32 })
33 }
34
35 pub async fn generate(
36 &self,
37 model: &str,
38 max_tokens: u64,
39 messages: Vec<Message>,
40 ) -> Result<AnthropicResponse> {
41 let request = AnthropicRequest {
42 model: model.to_string(),
43 max_tokens,
44 messages,
45 tools: Vec::new(),
46 thinking: None,
47 tool_choice: None,
48 system: None,
49 metadata: None,
50 stop_sequences: Vec::new(),
51 temperature: None,
52 top_k: None,
53 top_p: None,
54 };
55
56 let response = non_streaming_completion(
57 self.http_client.as_ref(),
58 ANTHROPIC_API_URL,
59 &self.api_key,
60 request,
61 None,
62 )
63 .await
64 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
65
66 Ok(response)
67 }
68
69 pub async fn generate_streaming<F>(
70 &self,
71 model: &str,
72 max_tokens: u64,
73 messages: Vec<Message>,
74 mut on_progress: F,
75 ) -> Result<AnthropicResponse>
76 where
77 F: FnMut(usize, &str),
78 {
79 let request = AnthropicRequest {
80 model: model.to_string(),
81 max_tokens,
82 messages,
83 tools: Vec::new(),
84 thinking: None,
85 tool_choice: None,
86 system: None,
87 metadata: None,
88 stop_sequences: Vec::new(),
89 temperature: None,
90 top_k: None,
91 top_p: None,
92 };
93
94 let mut stream = stream_completion(
95 self.http_client.as_ref(),
96 ANTHROPIC_API_URL,
97 &self.api_key,
98 request,
99 None,
100 )
101 .await
102 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
103
104 let mut response: Option<AnthropicResponse> = None;
105 let mut text_content = String::new();
106
107 while let Some(event_result) = stream.next().await {
108 let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
109
110 match event {
111 Event::MessageStart { message } => {
112 response = Some(message);
113 }
114 Event::ContentBlockDelta { delta, .. } => {
115 if let anthropic::ContentDelta::TextDelta { text } = delta {
116 text_content.push_str(&text);
117 on_progress(text_content.len(), &text_content);
118 }
119 }
120 _ => {}
121 }
122 }
123
124 let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
125
126 if response.content.is_empty() && !text_content.is_empty() {
127 response
128 .content
129 .push(ResponseContent::Text { text: text_content });
130 }
131
132 Ok(response)
133 }
134}
135
/// An Anthropic client that never calls the API inline: requests are recorded
/// in a local SQLite cache and uploaded/downloaded via the Batches API by
/// `sync_batches`.
pub struct BatchingLlmClient {
    // SQLite cache of requests/responses; mutex-guarded for shared access.
    connection: Mutex<sqlez::connection::Connection>,
    // HTTP transport used for all batch API calls.
    http_client: Arc<dyn HttpClient>,
    // API key read from the ANTHROPIC_API_KEY environment variable.
    api_key: String,
}
141
/// One row of the `cache` table: a pending or completed batched request.
struct CacheRow {
    // Hex hash of (model, max_tokens, message text) — the primary key.
    request_hash: String,
    // JSON-serialized `SerializableRequest`, set when the request is marked for batching.
    request: Option<String>,
    // JSON response (or error payload) once the batch finishes; NULL while pending.
    response: Option<String>,
    // Anthropic batch ID once uploaded; NULL before upload.
    batch_id: Option<String>,
}
148
impl StaticColumnCount for CacheRow {
    /// A `CacheRow` binds exactly four columns:
    /// request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
154
155impl Bind for CacheRow {
156 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
157 let next_index = statement.bind(&self.request_hash, start_index)?;
158 let next_index = statement.bind(&self.request, next_index)?;
159 let next_index = statement.bind(&self.response, next_index)?;
160 let next_index = statement.bind(&self.batch_id, next_index)?;
161 Ok(next_index)
162 }
163}
164
/// JSON-serializable snapshot of a request, stored in the cache's `request`
/// column so it can be replayed when the batch is uploaded.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
171
/// Flattened form of a chat message: role as a string ("user"/"assistant")
/// and all text content joined into one string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
177
178impl BatchingLlmClient {
179 fn new(cache_path: &Path) -> Result<Self> {
180 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
181 let api_key = std::env::var("ANTHROPIC_API_KEY")
182 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
183
184 let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
185 let mut statement = sqlez::statement::Statement::prepare(
186 &connection,
187 indoc! {"
188 CREATE TABLE IF NOT EXISTS cache (
189 request_hash TEXT PRIMARY KEY,
190 request TEXT,
191 response TEXT,
192 batch_id TEXT
193 );
194 "},
195 )?;
196 statement.exec()?;
197 drop(statement);
198
199 Ok(Self {
200 connection: Mutex::new(connection),
201 http_client,
202 api_key,
203 })
204 }
205
    /// Looks up a cached response for the given request parameters.
    ///
    /// Returns `Ok(None)` when no row exists, when the row's response column
    /// is NULL, or when the stored JSON fails to deserialize into an
    /// `AnthropicResponse` — note the `.ok()` below silently treats
    /// non-deserializable payloads (e.g. error JSON written by batch
    /// download/import) as cache misses.
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[Message],
    ) -> Result<Option<AnthropicResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages);
        let connection = self.connection.lock().unwrap();
        let response: Vec<String> = connection.select_bound(
            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
        )?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }
222
223 pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
224 let request_hash = Self::request_hash(model, max_tokens, messages);
225
226 let serializable_messages: Vec<SerializableMessage> = messages
227 .iter()
228 .map(|msg| SerializableMessage {
229 role: match msg.role {
230 Role::User => "user".to_string(),
231 Role::Assistant => "assistant".to_string(),
232 },
233 content: message_content_to_string(&msg.content),
234 })
235 .collect();
236
237 let serializable_request = SerializableRequest {
238 model: model.to_string(),
239 max_tokens,
240 messages: serializable_messages,
241 };
242
243 let request = Some(serde_json::to_string(&serializable_request)?);
244 let cache_row = CacheRow {
245 request_hash,
246 request,
247 response: None,
248 batch_id: None,
249 };
250 let connection = self.connection.lock().unwrap();
251 connection.exec_bound::<CacheRow>(sql!(
252 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
253 cache_row,
254 )
255 }
256
257 async fn generate(
258 &self,
259 model: &str,
260 max_tokens: u64,
261 messages: Vec<Message>,
262 ) -> Result<Option<AnthropicResponse>> {
263 let response = self.lookup(model, max_tokens, &messages)?;
264 if let Some(response) = response {
265 return Ok(Some(response));
266 }
267
268 self.mark_for_batch(model, max_tokens, &messages)?;
269
270 Ok(None)
271 }
272
    /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
    async fn sync_batches(&self) -> Result<()> {
        // Upload first so freshly-marked requests acquire batch IDs before polling.
        let _batch_ids = self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }
278
    /// Import batch results from external batch IDs (useful for recovering after database loss)
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing batch {}", batch_id);

            // Check processing status before attempting to fetch results.
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Only fully-processed batches can be imported; skip the rest.
            if batch_status.processing_status != "ended" {
                log::warn!(
                    "Batch {} is not finished (status: {}), skipping",
                    batch_id,
                    batch_status.processing_status
                );
                continue;
            }

            let results = anthropic::batches::retrieve_batch_results(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
            })?;

            // (request_hash, response_json, batch_id) triples to write back below.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // custom_id was written as "req_hash_<hash>" at upload time;
                // fall back to the raw custom_id for externally-created batches.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                match result.result {
                    anthropic::batches::BatchResult::Succeeded { message } => {
                        let response_json = serde_json::to_string(&message)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    }
                    anthropic::batches::BatchResult::Errored { error } => {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.error.error_type,
                            error.error.message
                        );
                        // Persist the error payload so the row is considered answered.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.error.error_type,
                                "message": error.error.message
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Canceled => {
                        // NOTE(review): canceled results are counted but not written
                        // back here, unlike in download_finished_batches — confirm
                        // this asymmetry is intentional.
                        log::warn!("Batch request {} was canceled", request_hash);
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Expired => {
                        // NOTE(review): same asymmetry as Canceled above.
                        log::warn!("Batch request {} expired", request_hash);
                        error_count += 1;
                    }
                }
            }

            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
                // The sub-select preserves any existing request JSON for the row.
                let q = sql!(
                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
393
    /// Polls every uploaded batch whose responses are still missing; for each
    /// batch whose status is "ended", writes all results — successes and
    /// error payloads alike — into the cache's `response` column.
    async fn download_finished_batches(&self) -> Result<()> {
        // Collect batch IDs first so the connection lock is not held across awaits.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                // (response_json, request_hash) pairs, matching the UPDATE below.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;
                for result in results {
                    // custom_id was written as "req_hash_<hash>" at upload time.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        }
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!(
                                "Batch request {} failed: {}: {}",
                                request_hash,
                                error.error.error_type,
                                error.error.message
                            );
                            // Store the error payload so the request is not re-uploaded.
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": error.error.error_type,
                                    "message": error.error.message
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "canceled",
                                    "message": "Batch request was canceled"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "expired",
                                    "message": "Batch request expired"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    }
                }

                // Apply all writes atomically within a savepoint.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
498
499 async fn upload_pending_requests(&self) -> Result<Vec<String>> {
500 const BATCH_CHUNK_SIZE: i32 = 16_000;
501 let mut all_batch_ids = Vec::new();
502 let mut total_uploaded = 0;
503
504 loop {
505 let rows: Vec<(String, String)> = {
506 let connection = self.connection.lock().unwrap();
507 let q = sql!(
508 SELECT request_hash, request FROM cache
509 WHERE batch_id IS NULL AND response IS NULL
510 LIMIT ?
511 );
512 connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
513 };
514
515 if rows.is_empty() {
516 break;
517 }
518
519 let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
520
521 let batch_requests = rows
522 .iter()
523 .map(|(hash, request_str)| {
524 let serializable_request: SerializableRequest =
525 serde_json::from_str(&request_str).unwrap();
526
527 let messages: Vec<Message> = serializable_request
528 .messages
529 .into_iter()
530 .map(|msg| Message {
531 role: match msg.role.as_str() {
532 "user" => Role::User,
533 "assistant" => Role::Assistant,
534 _ => Role::User,
535 },
536 content: vec![RequestContent::Text {
537 text: msg.content,
538 cache_control: None,
539 }],
540 })
541 .collect();
542
543 let params = AnthropicRequest {
544 model: serializable_request.model,
545 max_tokens: serializable_request.max_tokens,
546 messages,
547 tools: Vec::new(),
548 thinking: None,
549 tool_choice: None,
550 system: None,
551 metadata: None,
552 stop_sequences: Vec::new(),
553 temperature: None,
554 top_k: None,
555 top_p: None,
556 };
557
558 let custom_id = format!("req_hash_{}", hash);
559 anthropic::batches::BatchRequest { custom_id, params }
560 })
561 .collect::<Vec<_>>();
562
563 let batch_len = batch_requests.len();
564 let batch = anthropic::batches::create_batch(
565 self.http_client.as_ref(),
566 ANTHROPIC_API_URL,
567 &self.api_key,
568 anthropic::batches::CreateBatchRequest {
569 requests: batch_requests,
570 },
571 )
572 .await
573 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
574
575 {
576 let connection = self.connection.lock().unwrap();
577 connection.with_savepoint("batch_upload", || {
578 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
579 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
580 for hash in &request_hashes {
581 exec((batch.id.as_str(), hash.as_str()))?;
582 }
583 Ok(())
584 })?;
585 }
586
587 total_uploaded += batch_len;
588 log::info!(
589 "Uploaded batch {} with {} requests ({} total)",
590 batch.id,
591 batch_len,
592 total_uploaded
593 );
594
595 all_batch_ids.push(batch.id);
596 }
597
598 if !all_batch_ids.is_empty() {
599 log::info!(
600 "Finished uploading {} batches with {} total requests",
601 all_batch_ids.len(),
602 total_uploaded
603 );
604 }
605
606 Ok(all_batch_ids)
607 }
608
609 fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
610 let mut hasher = std::hash::DefaultHasher::new();
611 model.hash(&mut hasher);
612 max_tokens.hash(&mut hasher);
613 for msg in messages {
614 message_content_to_string(&msg.content).hash(&mut hasher);
615 }
616 let request_hash = hasher.finish();
617 format!("{request_hash:016x}")
618 }
619}
620
621fn message_content_to_string(content: &[RequestContent]) -> String {
622 content
623 .iter()
624 .filter_map(|c| match c {
625 RequestContent::Text { text, .. } => Some(text.clone()),
626 _ => None,
627 })
628 .collect::<Vec<String>>()
629 .join("\n")
630}
631
/// Front-end over the available client implementations.
pub enum AnthropicClient {
    // No batching
    Plain(PlainLlmClient),
    /// Caches requests in SQLite and serves them via the Batches API.
    Batch(BatchingLlmClient),
    /// Placeholder that panics if any API method is called on it.
    Dummy,
}
638
639impl AnthropicClient {
640 pub fn plain() -> Result<Self> {
641 Ok(Self::Plain(PlainLlmClient::new()?))
642 }
643
644 pub fn batch(cache_path: &Path) -> Result<Self> {
645 Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
646 }
647
648 #[allow(dead_code)]
649 pub fn dummy() -> Self {
650 Self::Dummy
651 }
652
653 pub async fn generate(
654 &self,
655 model: &str,
656 max_tokens: u64,
657 messages: Vec<Message>,
658 ) -> Result<Option<AnthropicResponse>> {
659 match self {
660 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
661 .generate(model, max_tokens, messages)
662 .await
663 .map(Some),
664 AnthropicClient::Batch(batching_llm_client) => {
665 batching_llm_client
666 .generate(model, max_tokens, messages)
667 .await
668 }
669 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
670 }
671 }
672
673 #[allow(dead_code)]
674 pub async fn generate_streaming<F>(
675 &self,
676 model: &str,
677 max_tokens: u64,
678 messages: Vec<Message>,
679 on_progress: F,
680 ) -> Result<Option<AnthropicResponse>>
681 where
682 F: FnMut(usize, &str),
683 {
684 match self {
685 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
686 .generate_streaming(model, max_tokens, messages, on_progress)
687 .await
688 .map(Some),
689 AnthropicClient::Batch(_) => {
690 anyhow::bail!("Streaming not supported with batching client")
691 }
692 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
693 }
694 }
695
696 pub async fn sync_batches(&self) -> Result<()> {
697 match self {
698 AnthropicClient::Plain(_) => Ok(()),
699 AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
700 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
701 }
702 }
703
704 pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
705 match self {
706 AnthropicClient::Plain(_) => {
707 anyhow::bail!("Import batches is only supported with batching client")
708 }
709 AnthropicClient::Batch(batching_llm_client) => {
710 batching_llm_client.import_batches(batch_ids).await
711 }
712 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
713 }
714 }
715}