1use anthropic::{
2 ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
4 stream_completion,
5};
6use anyhow::Result;
7use futures::StreamExt as _;
8use http_client::HttpClient;
9use indoc::indoc;
10use reqwest_client::ReqwestClient;
11use sqlez::bindable::Bind;
12use sqlez::bindable::StaticColumnCount;
13use sqlez_macros::sql;
14use std::hash::Hash;
15use std::hash::Hasher;
16use std::path::Path;
17use std::sync::{Arc, Mutex};
18
/// Anthropic API client that sends every request immediately (no batching).
pub struct PlainLlmClient {
    /// HTTP transport used for all API calls.
    pub http_client: Arc<dyn HttpClient>,
    /// Anthropic API key, read from `ANTHROPIC_API_KEY` by `new`.
    pub api_key: String,
}
23
24impl PlainLlmClient {
25 pub fn new() -> Result<Self> {
26 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
27 let api_key = std::env::var("ANTHROPIC_API_KEY")
28 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
29 Ok(Self {
30 http_client,
31 api_key,
32 })
33 }
34
35 pub async fn generate(
36 &self,
37 model: &str,
38 max_tokens: u64,
39 messages: Vec<Message>,
40 ) -> Result<AnthropicResponse> {
41 let request = AnthropicRequest {
42 model: model.to_string(),
43 max_tokens,
44 messages,
45 tools: Vec::new(),
46 thinking: None,
47 tool_choice: None,
48 system: None,
49 metadata: None,
50 stop_sequences: Vec::new(),
51 temperature: None,
52 top_k: None,
53 top_p: None,
54 };
55
56 let response = non_streaming_completion(
57 self.http_client.as_ref(),
58 ANTHROPIC_API_URL,
59 &self.api_key,
60 request,
61 None,
62 )
63 .await
64 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
65
66 Ok(response)
67 }
68
69 pub async fn generate_streaming<F>(
70 &self,
71 model: &str,
72 max_tokens: u64,
73 messages: Vec<Message>,
74 mut on_progress: F,
75 ) -> Result<AnthropicResponse>
76 where
77 F: FnMut(usize, &str),
78 {
79 let request = AnthropicRequest {
80 model: model.to_string(),
81 max_tokens,
82 messages,
83 tools: Vec::new(),
84 thinking: None,
85 tool_choice: None,
86 system: None,
87 metadata: None,
88 stop_sequences: Vec::new(),
89 temperature: None,
90 top_k: None,
91 top_p: None,
92 };
93
94 let mut stream = stream_completion(
95 self.http_client.as_ref(),
96 ANTHROPIC_API_URL,
97 &self.api_key,
98 request,
99 None,
100 )
101 .await
102 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
103
104 let mut response: Option<AnthropicResponse> = None;
105 let mut text_content = String::new();
106
107 while let Some(event_result) = stream.next().await {
108 let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
109
110 match event {
111 Event::MessageStart { message } => {
112 response = Some(message);
113 }
114 Event::ContentBlockDelta { delta, .. } => {
115 if let anthropic::ContentDelta::TextDelta { text } = delta {
116 text_content.push_str(&text);
117 on_progress(text_content.len(), &text_content);
118 }
119 }
120 _ => {}
121 }
122 }
123
124 let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
125
126 if response.content.is_empty() && !text_content.is_empty() {
127 response
128 .content
129 .push(ResponseContent::Text { text: text_content });
130 }
131
132 Ok(response)
133 }
134}
135
/// Anthropic client that queues requests in a SQLite-backed cache and submits
/// them through the Anthropic batches API instead of calling synchronously.
pub struct BatchingLlmClient {
    /// SQLite cache of requests/responses; the Mutex lets `&self` methods
    /// take exclusive access to the connection.
    connection: Mutex<sqlez::connection::Connection>,
    /// HTTP transport used for batch API calls.
    http_client: Arc<dyn HttpClient>,
    /// Anthropic API key, read from `ANTHROPIC_API_KEY` by `new`.
    api_key: String,
}
141
/// One row of the `cache` table (schema created in `BatchingLlmClient::new`).
struct CacheRow {
    /// Hex-encoded request hash; primary key of the table.
    request_hash: String,
    /// JSON-serialized `SerializableRequest`, if recorded.
    request: Option<String>,
    /// JSON-serialized response (or error object) once downloaded.
    response: Option<String>,
    /// Anthropic batch id this request was uploaded in, if any.
    batch_id: Option<String>,
}
148
impl StaticColumnCount for CacheRow {
    /// A `CacheRow` binds exactly four SQL columns:
    /// request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
154
155impl Bind for CacheRow {
156 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
157 let next_index = statement.bind(&self.request_hash, start_index)?;
158 let next_index = statement.bind(&self.request, next_index)?;
159 let next_index = statement.bind(&self.response, next_index)?;
160 let next_index = statement.bind(&self.batch_id, next_index)?;
161 Ok(next_index)
162 }
163}
164
/// JSON-storable form of a request, persisted in the cache's `request` column.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    /// Model identifier, e.g. a Claude model name.
    model: String,
    /// Maximum number of tokens to generate.
    max_tokens: u64,
    /// Conversation history, flattened to plain text per message.
    messages: Vec<SerializableMessage>,
}
171
/// JSON-storable form of a message: a role string ("user"/"assistant") plus
/// the message's text content joined into a single string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
177
178impl BatchingLlmClient {
179 fn new(cache_path: &Path) -> Result<Self> {
180 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
181 let api_key = std::env::var("ANTHROPIC_API_KEY")
182 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
183
184 let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
185 let mut statement = sqlez::statement::Statement::prepare(
186 &connection,
187 indoc! {"
188 CREATE TABLE IF NOT EXISTS cache (
189 request_hash TEXT PRIMARY KEY,
190 request TEXT,
191 response TEXT,
192 batch_id TEXT
193 );
194 "},
195 )?;
196 statement.exec()?;
197 drop(statement);
198
199 Ok(Self {
200 connection: Mutex::new(connection),
201 http_client,
202 api_key,
203 })
204 }
205
206 pub fn lookup(
207 &self,
208 model: &str,
209 max_tokens: u64,
210 messages: &[Message],
211 seed: Option<usize>,
212 ) -> Result<Option<AnthropicResponse>> {
213 let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
214 let connection = self.connection.lock().unwrap();
215 let response: Vec<String> = connection.select_bound(
216 &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
217 )?(request_hash_str.as_str())?;
218 Ok(response
219 .into_iter()
220 .next()
221 .and_then(|text| serde_json::from_str(&text).ok()))
222 }
223
224 pub fn mark_for_batch(
225 &self,
226 model: &str,
227 max_tokens: u64,
228 messages: &[Message],
229 seed: Option<usize>,
230 ) -> Result<()> {
231 let request_hash = Self::request_hash(model, max_tokens, messages, seed);
232
233 let serializable_messages: Vec<SerializableMessage> = messages
234 .iter()
235 .map(|msg| SerializableMessage {
236 role: match msg.role {
237 Role::User => "user".to_string(),
238 Role::Assistant => "assistant".to_string(),
239 },
240 content: message_content_to_string(&msg.content),
241 })
242 .collect();
243
244 let serializable_request = SerializableRequest {
245 model: model.to_string(),
246 max_tokens,
247 messages: serializable_messages,
248 };
249
250 let request = Some(serde_json::to_string(&serializable_request)?);
251 let cache_row = CacheRow {
252 request_hash,
253 request,
254 response: None,
255 batch_id: None,
256 };
257 let connection = self.connection.lock().unwrap();
258 connection.exec_bound::<CacheRow>(sql!(
259 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
260 cache_row,
261 )
262 }
263
264 async fn generate(
265 &self,
266 model: &str,
267 max_tokens: u64,
268 messages: Vec<Message>,
269 seed: Option<usize>,
270 ) -> Result<Option<AnthropicResponse>> {
271 let response = self.lookup(model, max_tokens, &messages, seed)?;
272 if let Some(response) = response {
273 return Ok(Some(response));
274 }
275
276 self.mark_for_batch(model, max_tokens, &messages, seed)?;
277
278 Ok(None)
279 }
280
281 /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
282 async fn sync_batches(&self) -> Result<()> {
283 let _batch_ids = self.upload_pending_requests().await?;
284 self.download_finished_batches().await
285 }
286
    /// Import batch results from external batch IDs (useful for recovering after database loss).
    ///
    /// For each finished batch, every result (success or API error) is written
    /// into the cache as the row's response; unfinished batches are skipped
    /// with a warning.
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing batch {}", batch_id);

            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Only batches whose processing has "ended" have results to fetch.
            if batch_status.processing_status != "ended" {
                log::warn!(
                    "Batch {} is not finished (status: {}), skipping",
                    batch_id,
                    batch_status.processing_status
                );
                continue;
            }

            let results = anthropic::batches::retrieve_batch_results(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
            })?;

            // (request_hash, response JSON, batch id) triples to write back.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // custom_id was written as "req_hash_<hash>" at upload time;
                // tolerate ids missing the prefix.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                match result.result {
                    anthropic::batches::BatchResult::Succeeded { message } => {
                        let response_json = serde_json::to_string(&message)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    }
                    anthropic::batches::BatchResult::Errored { error } => {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.error.error_type,
                            error.error.message
                        );
                        // The error object is stored as the response, so the
                        // request will not be retried.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.error.error_type,
                                "message": error.error.message
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                    // NOTE(review): unlike download_finished_batches, canceled
                    // and expired results are only counted here, not written to
                    // the cache — those rows keep a NULL response and stay
                    // pending. Confirm this asymmetry is intended.
                    anthropic::batches::BatchResult::Canceled => {
                        log::warn!("Batch request {} was canceled", request_hash);
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Expired => {
                        log::warn!("Batch request {} expired", request_hash);
                        error_count += 1;
                    }
                }
            }

            // Apply all writes atomically within a savepoint.
            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Use INSERT OR REPLACE to handle both new entries and updating existing ones;
                // the inner SELECT preserves any existing request JSON for the row.
                let q = sql!(
                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
401
    /// Polls every batch that has been uploaded but whose responses are not
    /// yet cached; for each batch that has ended, stores every result —
    /// success, API error, cancellation, or expiry — as the row's response.
    async fn download_finished_batches(&self) -> Result<()> {
        // Collect ids up front so the connection lock is not held across the
        // awaits below.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Batches that have not "ended" are left pending for a later sync.
            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                // (response JSON, request_hash) pairs to write back.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;
                for result in results {
                    // custom_id was written as "req_hash_<hash>" at upload
                    // time; tolerate ids missing the prefix.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        }
                        // All three failure kinds store an error object as the
                        // response, so the request will not be retried.
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!(
                                "Batch request {} failed: {}: {}",
                                request_hash,
                                error.error.error_type,
                                error.error.message
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": error.error.error_type,
                                    "message": error.error.message
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "canceled",
                                    "message": "Batch request was canceled"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "expired",
                                    "message": "Batch request expired"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    }
                }

                // Apply all writes atomically within a savepoint.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                // Note: only successes are counted here, though error results
                // were also written above.
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
506
507 async fn upload_pending_requests(&self) -> Result<Vec<String>> {
508 const BATCH_CHUNK_SIZE: i32 = 16_000;
509 let mut all_batch_ids = Vec::new();
510 let mut total_uploaded = 0;
511
512 loop {
513 let rows: Vec<(String, String)> = {
514 let connection = self.connection.lock().unwrap();
515 let q = sql!(
516 SELECT request_hash, request FROM cache
517 WHERE batch_id IS NULL AND response IS NULL
518 LIMIT ?
519 );
520 connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
521 };
522
523 if rows.is_empty() {
524 break;
525 }
526
527 let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
528
529 let batch_requests = rows
530 .iter()
531 .map(|(hash, request_str)| {
532 let serializable_request: SerializableRequest =
533 serde_json::from_str(&request_str).unwrap();
534
535 let messages: Vec<Message> = serializable_request
536 .messages
537 .into_iter()
538 .map(|msg| Message {
539 role: match msg.role.as_str() {
540 "user" => Role::User,
541 "assistant" => Role::Assistant,
542 _ => Role::User,
543 },
544 content: vec![RequestContent::Text {
545 text: msg.content,
546 cache_control: None,
547 }],
548 })
549 .collect();
550
551 let params = AnthropicRequest {
552 model: serializable_request.model,
553 max_tokens: serializable_request.max_tokens,
554 messages,
555 tools: Vec::new(),
556 thinking: None,
557 tool_choice: None,
558 system: None,
559 metadata: None,
560 stop_sequences: Vec::new(),
561 temperature: None,
562 top_k: None,
563 top_p: None,
564 };
565
566 let custom_id = format!("req_hash_{}", hash);
567 anthropic::batches::BatchRequest { custom_id, params }
568 })
569 .collect::<Vec<_>>();
570
571 let batch_len = batch_requests.len();
572 let batch = anthropic::batches::create_batch(
573 self.http_client.as_ref(),
574 ANTHROPIC_API_URL,
575 &self.api_key,
576 anthropic::batches::CreateBatchRequest {
577 requests: batch_requests,
578 },
579 )
580 .await
581 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
582
583 {
584 let connection = self.connection.lock().unwrap();
585 connection.with_savepoint("batch_upload", || {
586 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
587 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
588 for hash in &request_hashes {
589 exec((batch.id.as_str(), hash.as_str()))?;
590 }
591 Ok(())
592 })?;
593 }
594
595 total_uploaded += batch_len;
596 log::info!(
597 "Uploaded batch {} with {} requests ({} total)",
598 batch.id,
599 batch_len,
600 total_uploaded
601 );
602
603 all_batch_ids.push(batch.id);
604 }
605
606 if !all_batch_ids.is_empty() {
607 log::info!(
608 "Finished uploading {} batches with {} total requests",
609 all_batch_ids.len(),
610 total_uploaded
611 );
612 }
613
614 Ok(all_batch_ids)
615 }
616
    /// Computes the cache key for a request from the model, token limit,
    /// message text, and optional seed.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable
    /// across Rust releases, so a persisted cache may be invalidated by a
    /// toolchain upgrade — confirm that is acceptable. Also note that message
    /// roles are NOT hashed: two requests with identical text but different
    /// roles produce the same key.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            // Only the concatenated text content of each message participates.
            message_content_to_string(&msg.content).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        // 16 lowercase hex digits, zero-padded.
        format!("{request_hash:016x}")
    }
635}
636
637fn message_content_to_string(content: &[RequestContent]) -> String {
638 content
639 .iter()
640 .filter_map(|c| match c {
641 RequestContent::Text { text, .. } => Some(text.clone()),
642 _ => None,
643 })
644 .collect::<Vec<String>>()
645 .join("\n")
646}
647
/// Front-end over the two client implementations plus a placeholder variant.
pub enum AnthropicClient {
    /// Sends every request immediately — no batching.
    Plain(PlainLlmClient),
    /// Queues requests through the Anthropic batches API with a SQLite cache.
    Batch(BatchingLlmClient),
    /// Panics if any request method is invoked; for code paths that are
    /// expected never to call the LLM.
    Dummy,
}
654
655impl AnthropicClient {
656 pub fn plain() -> Result<Self> {
657 Ok(Self::Plain(PlainLlmClient::new()?))
658 }
659
660 pub fn batch(cache_path: &Path) -> Result<Self> {
661 Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
662 }
663
664 #[allow(dead_code)]
665 pub fn dummy() -> Self {
666 Self::Dummy
667 }
668
669 pub async fn generate(
670 &self,
671 model: &str,
672 max_tokens: u64,
673 messages: Vec<Message>,
674 seed: Option<usize>,
675 ) -> Result<Option<AnthropicResponse>> {
676 match self {
677 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
678 .generate(model, max_tokens, messages)
679 .await
680 .map(Some),
681 AnthropicClient::Batch(batching_llm_client) => {
682 batching_llm_client
683 .generate(model, max_tokens, messages, seed)
684 .await
685 }
686 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
687 }
688 }
689
690 #[allow(dead_code)]
691 pub async fn generate_streaming<F>(
692 &self,
693 model: &str,
694 max_tokens: u64,
695 messages: Vec<Message>,
696 on_progress: F,
697 ) -> Result<Option<AnthropicResponse>>
698 where
699 F: FnMut(usize, &str),
700 {
701 match self {
702 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
703 .generate_streaming(model, max_tokens, messages, on_progress)
704 .await
705 .map(Some),
706 AnthropicClient::Batch(_) => {
707 anyhow::bail!("Streaming not supported with batching client")
708 }
709 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
710 }
711 }
712
713 pub async fn sync_batches(&self) -> Result<()> {
714 match self {
715 AnthropicClient::Plain(_) => Ok(()),
716 AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
717 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
718 }
719 }
720
721 pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
722 match self {
723 AnthropicClient::Plain(_) => {
724 anyhow::bail!("Import batches is only supported with batching client")
725 }
726 AnthropicClient::Batch(batching_llm_client) => {
727 batching_llm_client.import_batches(batch_ids).await
728 }
729 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
730 }
731 }
732}