use anthropic::{
    ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
    Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
    stream_completion,
};
use anyhow::Result;
use futures::StreamExt as _;
use http_client::HttpClient;
use indoc::indoc;
use reqwest_client::ReqwestClient;
use sqlez::bindable::{Bind, StaticColumnCount};
use sqlez_macros::sql;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::Arc;

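/// A thin client that performs one direct HTTP request per call to the
/// Anthropic Messages API, with no caching or batching.
///
/// A minimal usage sketch (assumes `ANTHROPIC_API_KEY` is set and an async
/// runtime is running; the model name is illustrative):
///
/// ```ignore
/// let client = PlainLlmClient::new()?;
/// let response = client
///     .generate(
///         "claude-sonnet-4-0",
///         1024,
///         vec![Message {
///             role: Role::User,
///             content: vec![RequestContent::Text {
///                 text: "Say hello.".to_string(),
///                 cache_control: None,
///             }],
///         }],
///     )
///     .await?;
/// ```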
pub struct PlainLlmClient {
    pub http_client: Arc<dyn HttpClient>,
    pub api_key: String,
}

impl PlainLlmClient {
    pub fn new() -> Result<Self> {
        let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
        Ok(Self {
            http_client,
            api_key,
        })
    }

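    /// Sends a single non-streaming completion request and returns the full response.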
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
    ) -> Result<AnthropicResponse> {
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens,
            messages,
            tools: Vec::new(),
            thinking: None,
            tool_choice: None,
            system: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: None,
            top_k: None,
            top_p: None,
        };

        let response = non_streaming_completion(
            self.http_client.as_ref(),
            ANTHROPIC_API_URL,
            &self.api_key,
            request,
            None,
        )
        .await
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;

        Ok(response)
    }

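    /// Streams a completion, invoking `on_progress` with the byte length and
    /// full accumulated text after every text delta.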
    pub async fn generate_streaming<F>(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
        mut on_progress: F,
    ) -> Result<AnthropicResponse>
    where
        F: FnMut(usize, &str),
    {
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens,
            messages,
            tools: Vec::new(),
            thinking: None,
            tool_choice: None,
            system: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: None,
            top_k: None,
            top_p: None,
        };

        let mut stream = stream_completion(
            self.http_client.as_ref(),
            ANTHROPIC_API_URL,
            &self.api_key,
            request,
            None,
        )
        .await
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;

        let mut response: Option<AnthropicResponse> = None;
        let mut text_content = String::new();

        while let Some(event_result) = stream.next().await {
            let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;

            match event {
                // The first event carries the response envelope (id, model, usage).
                Event::MessageStart { message } => {
                    response = Some(message);
                }
                // Accumulate text deltas and report progress after each one.
                Event::ContentBlockDelta { delta, .. } => {
                    if let anthropic::ContentDelta::TextDelta { text } = delta {
                        text_content.push_str(&text);
                        on_progress(text_content.len(), &text_content);
                    }
                }
                _ => {}
            }
        }

        let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;

        // The envelope arrives before any content, so splice the accumulated
        // text back in if the content is still empty.
        if response.content.is_empty() && !text_content.is_empty() {
            response
                .content
                .push(ResponseContent::Text { text: text_content });
        }

        Ok(response)
    }
}

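/// A caching client built on the Anthropic Message Batches API. Requests are
/// keyed by a hash and stored in a SQLite `cache` table; `generate` either
/// serves a cached response or records the request for the next batch upload.
///
/// An illustrative driver loop (a sketch only: the model name, polling
/// cadence, and `tokio` sleep are assumptions, not part of this module):
///
/// ```ignore
/// let client = AnthropicClient::batch(Path::new("cache.db"))?;
/// // First pass: a cache miss records the request as pending and yields None.
/// let mut result = client.generate("claude-sonnet-4-0", 1024, messages.clone()).await?;
/// while result.is_none() {
///     // Upload pending requests, then poll for finished batches.
///     client.sync_batches().await?;
///     tokio::time::sleep(std::time::Duration::from_secs(60)).await;
///     result = client.generate("claude-sonnet-4-0", 1024, messages.clone()).await?;
/// }
/// ```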
pub struct BatchingLlmClient {
    connection: sqlez::connection::Connection,
    http_client: Arc<dyn HttpClient>,
    api_key: String,
}

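/// One row of the `cache` table. A request moves through three states:
/// pending (`batch_id` and `response` both NULL), uploaded (`batch_id` set),
/// and answered (`response` set).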
struct CacheRow {
    request_hash: String,
    request: Option<String>,
    response: Option<String>,
    batch_id: Option<String>,
}

impl StaticColumnCount for CacheRow {
    fn column_count() -> usize {
        4
    }
}

impl Bind for CacheRow {
    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
        let next_index = statement.bind(&self.request_hash, start_index)?;
        let next_index = statement.bind(&self.request, next_index)?;
        let next_index = statement.bind(&self.response, next_index)?;
        let next_index = statement.bind(&self.batch_id, next_index)?;
        Ok(next_index)
    }
}

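/// A text-only snapshot of a request, persisted in the `request` column so
/// pending work survives restarts. Only text content survives the round-trip;
/// images and tool content are dropped by `message_content_to_string`.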
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}

impl BatchingLlmClient {
    fn new(cache_path: &Path) -> Result<Self> {
        let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;

        let connection = sqlez::connection::Connection::open_file(cache_path.to_str().unwrap());
        let mut statement = sqlez::statement::Statement::prepare(
            &connection,
            indoc! {"
                CREATE TABLE IF NOT EXISTS cache (
                    request_hash TEXT PRIMARY KEY,
                    request TEXT,
                    response TEXT,
                    batch_id TEXT
                );
            "},
        )?;
        statement.exec()?;
        drop(statement);

        Ok(Self {
            connection,
            http_client,
            api_key,
        })
    }

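    /// Returns the cached response for this (model, max_tokens, messages)
    /// combination, or `None` if no batch has produced one yet.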
    pub fn lookup(
        &self,
        model: &str,
        max_tokens: u64,
        messages: &[Message],
    ) -> Result<Option<AnthropicResponse>> {
        let request_hash_str = Self::request_hash(model, max_tokens, messages);
        let response: Vec<String> = self.connection.select_bound(sql!(
            SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;
        ))?(request_hash_str.as_str())?;
        Ok(response
            .into_iter()
            .next()
            .and_then(|text| serde_json::from_str(&text).ok()))
    }

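    /// Records the request as pending so the next `sync_batches` call bundles
    /// it into a new batch. `INSERT OR IGNORE` makes this idempotent: marking
    /// the same request twice leaves the existing row untouched.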
    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
        let request_hash = Self::request_hash(model, max_tokens, messages);

        let serializable_messages: Vec<SerializableMessage> = messages
            .iter()
            .map(|msg| SerializableMessage {
                role: match msg.role {
                    Role::User => "user".to_string(),
                    Role::Assistant => "assistant".to_string(),
                },
                content: message_content_to_string(&msg.content),
            })
            .collect();

        let serializable_request = SerializableRequest {
            model: model.to_string(),
            max_tokens,
            messages: serializable_messages,
        };

        let request = Some(serde_json::to_string(&serializable_request)?);
        let cache_row = CacheRow {
            request_hash,
            request,
            response: None,
            batch_id: None,
        };
        self.connection.exec_bound(sql!(
            INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)
        ))?(cache_row)
    }

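    /// Cache-or-defer: returns the cached response if one exists, otherwise
    /// marks the request for the next batch upload and returns `None`.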
    async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
    ) -> Result<Option<AnthropicResponse>> {
        let response = self.lookup(model, max_tokens, &messages)?;
        if let Some(response) = response {
            return Ok(Some(response));
        }

        self.mark_for_batch(model, max_tokens, &messages)?;

        Ok(None)
    }

    /// Uploads pending requests as a new batch; downloads finished batches if any.
    async fn sync_batches(&self) -> Result<()> {
        self.upload_pending_requests().await?;
        self.download_finished_batches().await
    }

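    /// Polls every batch that still has unanswered rows; for each batch whose
    /// processing has ended, writes the per-request results back into the cache.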
    async fn download_finished_batches(&self) -> Result<()> {
        let q = sql!(
            SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL
        );
        let batch_ids: Vec<String> = self.connection.select(q)?()?;

        for batch_id in batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                let mut success_count = 0;
                for result in results {
                    // Recover the cache key from the custom_id assigned at upload time.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                            self.connection.exec_bound(q)?((response_json, request_hash))?;
                            success_count += 1;
                        }
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!("Batch request {} failed: {:?}", request_hash, error);
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                        }
                    }
                }
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }

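    /// Gathers every pending row into a single batch, uploads it, and stamps
    /// the rows with the returned batch id. Returns the batch id, or an empty
    /// string when there was nothing to upload.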
    async fn upload_pending_requests(&self) -> Result<String> {
        let q = sql!(
            SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
        );

        let rows: Vec<(String, String)> = self.connection.select(q)?()?;

        if rows.is_empty() {
            return Ok(String::new());
        }

        let batch_requests = rows
            .iter()
            .map(|(hash, request_str)| {
                let serializable_request: SerializableRequest = serde_json::from_str(request_str)
                    .expect("cached request should be valid JSON");

                let messages: Vec<Message> = serializable_request
                    .messages
                    .into_iter()
                    .map(|msg| Message {
                        role: match msg.role.as_str() {
                            "user" => Role::User,
                            "assistant" => Role::Assistant,
                            _ => Role::User,
                        },
                        content: vec![RequestContent::Text {
                            text: msg.content,
                            cache_control: None,
                        }],
                    })
                    .collect();

                let params = AnthropicRequest {
                    model: serializable_request.model,
                    max_tokens: serializable_request.max_tokens,
                    messages,
                    tools: Vec::new(),
                    thinking: None,
                    tool_choice: None,
                    system: None,
                    metadata: None,
                    stop_sequences: Vec::new(),
                    temperature: None,
                    top_k: None,
                    top_p: None,
                };

                // The request hash doubles as the batch custom_id so results can
                // be matched back to cache rows on download.
                let custom_id = format!("req_hash_{}", hash);
                anthropic::batches::BatchRequest { custom_id, params }
            })
            .collect::<Vec<_>>();

        let batch_len = batch_requests.len();
        let batch = anthropic::batches::create_batch(
            self.http_client.as_ref(),
            ANTHROPIC_API_URL,
            &self.api_key,
            anthropic::batches::CreateBatchRequest {
                requests: batch_requests,
            },
        )
        .await
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;

        // Stamp the uploaded rows. This mirrors the SELECT above; it assumes no
        // new requests were marked between the SELECT and this UPDATE.
        let q = sql!(
            UPDATE cache SET batch_id = ? WHERE batch_id IS NULL AND response IS NULL
        );
        self.connection.exec_bound(q)?(batch.id.as_str())?;

        log::info!("Uploaded batch with {} requests", batch_len);

        Ok(batch.id)
    }

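    /// Deterministic cache key over the model, token limit, and the text of
    /// each message. `DefaultHasher::new()` uses fixed keys, so hashes are
    /// stable across runs of the same toolchain, but stability is not
    /// guaranteed across Rust versions; a compiler upgrade may invalidate
    /// the cache.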
    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(&msg.content).hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        format!("{request_hash:016x}")
    }
}

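/// Joins the text blocks of a message with newlines, dropping non-text
/// content such as images and tool calls.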
fn message_content_to_string(content: &[RequestContent]) -> String {
    content
        .iter()
        .filter_map(|c| match c {
            RequestContent::Text { text, .. } => Some(text.clone()),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n")
}

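/// Front-end over the two client flavors, plus a placeholder variant for
/// contexts where no LLM calls are expected.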
pub enum AnthropicClient {
    /// Direct, per-request HTTP calls; no batching.
    Plain(PlainLlmClient),
    /// SQLite-cached requests served through the Message Batches API.
    Batch(BatchingLlmClient),
    /// Panics on use.
    Dummy,
}

444
445impl AnthropicClient {
446 pub fn plain() -> Result<Self> {
447 Ok(Self::Plain(PlainLlmClient::new()?))
448 }
449
450 pub fn batch(cache_path: &Path) -> Result<Self> {
451 Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
452 }
453
454 #[allow(dead_code)]
455 pub fn dummy() -> Self {
456 Self::Dummy
457 }
458
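    /// Dispatches to the underlying client. The plain client always returns
    /// `Some(response)`; the batching client returns `None` on a cache miss,
    /// in which case the caller should run `sync_batches` and retry later.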
    pub async fn generate(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
    ) -> Result<Option<AnthropicResponse>> {
        match self {
            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
                .generate(model, max_tokens, messages)
                .await
                .map(Some),
            AnthropicClient::Batch(batching_llm_client) => {
                batching_llm_client
                    .generate(model, max_tokens, messages)
                    .await
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

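    /// Streaming is only supported by the plain client; the batching client
    /// returns an error because batched responses arrive asynchronously.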
    #[allow(dead_code)]
    pub async fn generate_streaming<F>(
        &self,
        model: &str,
        max_tokens: u64,
        messages: Vec<Message>,
        on_progress: F,
    ) -> Result<Option<AnthropicResponse>>
    where
        F: FnMut(usize, &str),
    {
        match self {
            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
                .generate_streaming(model, max_tokens, messages, on_progress)
                .await
                .map(Some),
            AnthropicClient::Batch(_) => {
                anyhow::bail!("Streaming not supported with batching client")
            }
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }

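    /// A no-op for the plain client; for the batching client, uploads pending
    /// requests and downloads any finished batches.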
    pub async fn sync_batches(&self) -> Result<()> {
        match self {
            AnthropicClient::Plain(_) => Ok(()),
            AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
        }
    }
}