1use anthropic::{
2 ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, Role, non_streaming_completion,
4};
5use anyhow::Result;
6use http_client::HttpClient;
7use indoc::indoc;
8use reqwest_client::ReqwestClient;
9use sqlez::bindable::Bind;
10use sqlez::bindable::StaticColumnCount;
11use sqlez_macros::sql;
12use std::hash::Hash;
13use std::hash::Hasher;
14use std::path::Path;
15use std::sync::Arc;
16
/// Direct (non-batching) Anthropic client: every `generate` call performs an
/// immediate HTTP request and returns the response.
pub struct PlainLlmClient {
    // Shared HTTP transport used for all API calls.
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key, read from the `ANTHROPIC_API_KEY` env var on construction.
    api_key: String,
}
21
22impl PlainLlmClient {
23 fn new() -> Result<Self> {
24 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
25 let api_key = std::env::var("ANTHROPIC_API_KEY")
26 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
27 Ok(Self {
28 http_client,
29 api_key,
30 })
31 }
32
33 async fn generate(
34 &self,
35 model: &str,
36 max_tokens: u64,
37 messages: Vec<Message>,
38 ) -> Result<AnthropicResponse> {
39 let request = AnthropicRequest {
40 model: model.to_string(),
41 max_tokens,
42 messages,
43 tools: Vec::new(),
44 thinking: None,
45 tool_choice: None,
46 system: None,
47 metadata: None,
48 stop_sequences: Vec::new(),
49 temperature: None,
50 top_k: None,
51 top_p: None,
52 };
53
54 let response = non_streaming_completion(
55 self.http_client.as_ref(),
56 ANTHROPIC_API_URL,
57 &self.api_key,
58 request,
59 None,
60 )
61 .await
62 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
63
64 Ok(response)
65 }
66}
67
/// Anthropic client that defers requests to the batch API, using a SQLite
/// file as a persistent request/response cache.
pub struct BatchingLlmClient {
    // Connection to the on-disk SQLite cache (table `cache`).
    connection: sqlez::connection::Connection,
    // Shared HTTP transport used for all API calls.
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key, read from the `ANTHROPIC_API_KEY` env var on construction.
    api_key: String,
}
73
/// One row of the SQLite `cache` table. Field order matches the table's
/// column order, which the `Bind` impl below relies on.
struct CacheRow {
    // Hex-encoded request hash; primary key of the table.
    request_hash: String,
    // JSON-serialized `SerializableRequest`, set when the request is queued.
    request: Option<String>,
    // JSON-serialized Anthropic response, set once the batch result is downloaded.
    response: Option<String>,
    // Id of the uploaded batch containing this request; `None` while pending.
    batch_id: Option<String>,
}
80
impl StaticColumnCount for CacheRow {
    /// Number of columns bound by `CacheRow`'s `Bind` impl
    /// (request_hash, request, response, batch_id).
    fn column_count() -> usize {
        4
    }
}
86
impl Bind for CacheRow {
    // Binds the four columns, in table-declaration order, starting at
    // `start_index`; returns the next free parameter index. The order here
    // must stay in sync with the `INSERT INTO cache(...)` column list.
    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
        let next_index = statement.bind(&self.request_hash, start_index)?;
        let next_index = statement.bind(&self.request, next_index)?;
        let next_index = statement.bind(&self.response, next_index)?;
        let next_index = statement.bind(&self.batch_id, next_index)?;
        Ok(next_index)
    }
}
96
/// JSON-storable subset of an Anthropic request, persisted in the `request`
/// column of the cache and rebuilt into a full request at upload time.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
103
/// Flattened message for cache storage: role as a string ("user"/"assistant")
/// and the message's text parts joined into one string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
109
110impl BatchingLlmClient {
111 fn new(cache_path: &Path) -> Result<Self> {
112 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
113 let api_key = std::env::var("ANTHROPIC_API_KEY")
114 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
115
116 let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
117 let mut statement = sqlez::statement::Statement::prepare(
118 &connection,
119 indoc! {"
120 CREATE TABLE IF NOT EXISTS cache (
121 request_hash TEXT PRIMARY KEY,
122 request TEXT,
123 response TEXT,
124 batch_id TEXT
125 );
126 "},
127 )?;
128 statement.exec()?;
129 drop(statement);
130
131 Ok(Self {
132 connection,
133 http_client,
134 api_key,
135 })
136 }
137
138 pub fn lookup(
139 &self,
140 model: &str,
141 max_tokens: u64,
142 messages: &[Message],
143 ) -> Result<Option<AnthropicResponse>> {
144 let request_hash_str = Self::request_hash(model, max_tokens, messages);
145 let response: Vec<String> = self.connection.select_bound(
146 &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
147 )?(request_hash_str.as_str())?;
148 Ok(response
149 .into_iter()
150 .next()
151 .and_then(|text| serde_json::from_str(&text).ok()))
152 }
153
154 pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
155 let request_hash = Self::request_hash(model, max_tokens, messages);
156
157 let serializable_messages: Vec<SerializableMessage> = messages
158 .iter()
159 .map(|msg| SerializableMessage {
160 role: match msg.role {
161 Role::User => "user".to_string(),
162 Role::Assistant => "assistant".to_string(),
163 },
164 content: message_content_to_string(&msg.content),
165 })
166 .collect();
167
168 let serializable_request = SerializableRequest {
169 model: model.to_string(),
170 max_tokens,
171 messages: serializable_messages,
172 };
173
174 let request = Some(serde_json::to_string(&serializable_request)?);
175 let cache_row = CacheRow {
176 request_hash,
177 request,
178 response: None,
179 batch_id: None,
180 };
181 self.connection.exec_bound(sql!(
182 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
183 cache_row,
184 )
185 }
186
187 async fn generate(
188 &self,
189 model: &str,
190 max_tokens: u64,
191 messages: Vec<Message>,
192 ) -> Result<Option<AnthropicResponse>> {
193 let response = self.lookup(model, max_tokens, &messages)?;
194 if let Some(response) = response {
195 return Ok(Some(response));
196 }
197
198 self.mark_for_batch(model, max_tokens, &messages)?;
199
200 Ok(None)
201 }
202
203 /// Uploads pending requests as a new batch; downloads finished batches if any.
204 async fn sync_batches(&self) -> Result<()> {
205 self.upload_pending_requests().await?;
206 self.download_finished_batches().await
207 }
208
209 async fn download_finished_batches(&self) -> Result<()> {
210 let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
211 let batch_ids: Vec<String> = self.connection.select(q)?()?;
212
213 for batch_id in batch_ids {
214 let batch_status = anthropic::batches::retrieve_batch(
215 self.http_client.as_ref(),
216 ANTHROPIC_API_URL,
217 &self.api_key,
218 &batch_id,
219 )
220 .await
221 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
222
223 log::info!(
224 "Batch {} status: {}",
225 batch_id,
226 batch_status.processing_status
227 );
228
229 if batch_status.processing_status == "ended" {
230 let results = anthropic::batches::retrieve_batch_results(
231 self.http_client.as_ref(),
232 ANTHROPIC_API_URL,
233 &self.api_key,
234 &batch_id,
235 )
236 .await
237 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
238
239 let mut success_count = 0;
240 for result in results {
241 let request_hash = result
242 .custom_id
243 .strip_prefix("req_hash_")
244 .unwrap_or(&result.custom_id)
245 .to_string();
246
247 match result.result {
248 anthropic::batches::BatchResult::Succeeded { message } => {
249 let response_json = serde_json::to_string(&message)?;
250 let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
251 self.connection.exec_bound(q)?((response_json, request_hash))?;
252 success_count += 1;
253 }
254 anthropic::batches::BatchResult::Errored { error } => {
255 log::error!("Batch request {} failed: {:?}", request_hash, error);
256 }
257 anthropic::batches::BatchResult::Canceled => {
258 log::warn!("Batch request {} was canceled", request_hash);
259 }
260 anthropic::batches::BatchResult::Expired => {
261 log::warn!("Batch request {} expired", request_hash);
262 }
263 }
264 }
265 log::info!("Downloaded {} successful requests", success_count);
266 }
267 }
268
269 Ok(())
270 }
271
272 async fn upload_pending_requests(&self) -> Result<String> {
273 let q = sql!(
274 SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
275 );
276
277 let rows: Vec<(String, String)> = self.connection.select(q)?()?;
278
279 if rows.is_empty() {
280 return Ok(String::new());
281 }
282
283 let batch_requests = rows
284 .iter()
285 .map(|(hash, request_str)| {
286 let serializable_request: SerializableRequest =
287 serde_json::from_str(&request_str).unwrap();
288
289 let messages: Vec<Message> = serializable_request
290 .messages
291 .into_iter()
292 .map(|msg| Message {
293 role: match msg.role.as_str() {
294 "user" => Role::User,
295 "assistant" => Role::Assistant,
296 _ => Role::User,
297 },
298 content: vec![RequestContent::Text {
299 text: msg.content,
300 cache_control: None,
301 }],
302 })
303 .collect();
304
305 let params = AnthropicRequest {
306 model: serializable_request.model,
307 max_tokens: serializable_request.max_tokens,
308 messages,
309 tools: Vec::new(),
310 thinking: None,
311 tool_choice: None,
312 system: None,
313 metadata: None,
314 stop_sequences: Vec::new(),
315 temperature: None,
316 top_k: None,
317 top_p: None,
318 };
319
320 let custom_id = format!("req_hash_{}", hash);
321 anthropic::batches::BatchRequest { custom_id, params }
322 })
323 .collect::<Vec<_>>();
324
325 let batch_len = batch_requests.len();
326 let batch = anthropic::batches::create_batch(
327 self.http_client.as_ref(),
328 ANTHROPIC_API_URL,
329 &self.api_key,
330 anthropic::batches::CreateBatchRequest {
331 requests: batch_requests,
332 },
333 )
334 .await
335 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
336
337 let q = sql!(
338 UPDATE cache SET batch_id = ? WHERE batch_id is NULL
339 );
340 self.connection.exec_bound(q)?(batch.id.as_str())?;
341
342 log::info!("Uploaded batch with {} requests", batch_len);
343
344 Ok(batch.id)
345 }
346
347 fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
348 let mut hasher = std::hash::DefaultHasher::new();
349 model.hash(&mut hasher);
350 max_tokens.hash(&mut hasher);
351 for msg in messages {
352 message_content_to_string(&msg.content).hash(&mut hasher);
353 }
354 let request_hash = hasher.finish();
355 format!("{request_hash:016x}")
356 }
357}
358
359fn message_content_to_string(content: &[RequestContent]) -> String {
360 content
361 .iter()
362 .filter_map(|c| match c {
363 RequestContent::Text { text, .. } => Some(text.clone()),
364 _ => None,
365 })
366 .collect::<Vec<String>>()
367 .join("\n")
368}
369
/// Front-end enum selecting one of the client strategies.
pub enum AnthropicClient {
    /// Sends each request immediately; no batching.
    Plain(PlainLlmClient),
    /// Defers requests to the batch API with a SQLite-backed cache.
    Batch(BatchingLlmClient),
    /// Placeholder that panics if used; for code paths that never call the LLM.
    Dummy,
}
376
377impl AnthropicClient {
378 pub fn plain() -> Result<Self> {
379 Ok(Self::Plain(PlainLlmClient::new()?))
380 }
381
382 pub fn batch(cache_path: &Path) -> Result<Self> {
383 Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
384 }
385
386 #[allow(dead_code)]
387 pub fn dummy() -> Self {
388 Self::Dummy
389 }
390
391 pub async fn generate(
392 &self,
393 model: &str,
394 max_tokens: u64,
395 messages: Vec<Message>,
396 ) -> Result<Option<AnthropicResponse>> {
397 match self {
398 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
399 .generate(model, max_tokens, messages)
400 .await
401 .map(Some),
402 AnthropicClient::Batch(batching_llm_client) => {
403 batching_llm_client
404 .generate(model, max_tokens, messages)
405 .await
406 }
407 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
408 }
409 }
410
411 pub async fn sync_batches(&self) -> Result<()> {
412 match self {
413 AnthropicClient::Plain(_) => Ok(()),
414 AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
415 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
416 }
417 }
418}