1use anthropic::{
2 ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, Role, non_streaming_completion,
4};
5use anyhow::Result;
6use http_client::HttpClient;
7use indoc::indoc;
8use sqlez::bindable::Bind;
9use sqlez::bindable::StaticColumnCount;
10use sqlez_macros::sql;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::sync::Arc;
14
/// LLM client that sends each request directly to the Anthropic API and
/// waits for the response — no caching, no batching.
pub struct PlainLlmClient {
    http_client: Arc<dyn HttpClient>,
    // Read from the ANTHROPIC_API_KEY environment variable in `new`.
    api_key: String,
}
19
20impl PlainLlmClient {
21 fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
22 let api_key = std::env::var("ANTHROPIC_API_KEY")
23 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
24 Ok(Self {
25 http_client,
26 api_key,
27 })
28 }
29
30 async fn generate(
31 &self,
32 model: String,
33 max_tokens: u64,
34 messages: Vec<Message>,
35 ) -> Result<AnthropicResponse> {
36 let request = AnthropicRequest {
37 model,
38 max_tokens,
39 messages,
40 tools: Vec::new(),
41 thinking: None,
42 tool_choice: None,
43 system: None,
44 metadata: None,
45 stop_sequences: Vec::new(),
46 temperature: None,
47 top_k: None,
48 top_p: None,
49 };
50
51 let response = non_streaming_completion(
52 self.http_client.as_ref(),
53 ANTHROPIC_API_URL,
54 &self.api_key,
55 request,
56 None,
57 )
58 .await
59 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
60
61 Ok(response)
62 }
63}
64
/// LLM client that routes requests through the Anthropic Message Batches API,
/// persisting requests and responses in a local SQLite cache.
pub struct BatchingLlmClient {
    // SQLite cache; the `cache` table is created in `new`.
    connection: sqlez::connection::Connection,
    http_client: Arc<dyn HttpClient>,
    // Read from the ANTHROPIC_API_KEY environment variable in `new`.
    api_key: String,
}
70
/// One row of the SQLite `cache` table.
struct CacheRow {
    // Hex-encoded hash of (model, max_tokens, message texts); primary key.
    request_hash: String,
    // JSON-serialized `SerializableRequest`, set when the request is queued.
    request: Option<String>,
    // JSON-serialized Anthropic response, set once a batch result arrives.
    response: Option<String>,
    // Id of the uploaded batch containing this request; `None` until uploaded.
    batch_id: Option<String>,
}
77
impl StaticColumnCount for CacheRow {
    /// `CacheRow` binds exactly four columns:
    /// request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
83
84impl Bind for CacheRow {
85 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
86 let next_index = statement.bind(&self.request_hash, start_index)?;
87 let next_index = statement.bind(&self.request, next_index)?;
88 let next_index = statement.bind(&self.response, next_index)?;
89 let next_index = statement.bind(&self.batch_id, next_index)?;
90 Ok(next_index)
91 }
92}
93
/// JSON-serializable snapshot of a request, stored in the cache's
/// `request` column and reconstituted at batch-upload time.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
100
/// JSON-serializable message: role as a string ("user"/"assistant") and the
/// message's text content flattened to a single string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
106
107impl BatchingLlmClient {
108 fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
109 let api_key = std::env::var("ANTHROPIC_API_KEY")
110 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
111
112 let connection = sqlez::connection::Connection::open_file(&cache_path);
113 let mut statement = sqlez::statement::Statement::prepare(
114 &connection,
115 indoc! {"
116 CREATE TABLE IF NOT EXISTS cache (
117 request_hash TEXT PRIMARY KEY,
118 request TEXT,
119 response TEXT,
120 batch_id TEXT
121 );
122 "},
123 )?;
124 statement.exec()?;
125 drop(statement);
126
127 Ok(Self {
128 connection,
129 http_client,
130 api_key,
131 })
132 }
133
134 pub fn lookup(
135 &self,
136 model: &str,
137 max_tokens: u64,
138 messages: &[Message],
139 ) -> Result<Option<AnthropicResponse>> {
140 let request_hash_str = Self::request_hash(model, max_tokens, messages);
141 let response: Vec<String> = self.connection.select_bound(
142 &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
143 )?(request_hash_str.as_str())?;
144 Ok(response
145 .into_iter()
146 .next()
147 .and_then(|text| serde_json::from_str(&text).ok()))
148 }
149
150 pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
151 let request_hash = Self::request_hash(model, max_tokens, messages);
152
153 let serializable_messages: Vec<SerializableMessage> = messages
154 .iter()
155 .map(|msg| SerializableMessage {
156 role: match msg.role {
157 Role::User => "user".to_string(),
158 Role::Assistant => "assistant".to_string(),
159 },
160 content: message_content_to_string(&msg.content),
161 })
162 .collect();
163
164 let serializable_request = SerializableRequest {
165 model: model.to_string(),
166 max_tokens,
167 messages: serializable_messages,
168 };
169
170 let request = Some(serde_json::to_string(&serializable_request)?);
171 let cache_row = CacheRow {
172 request_hash,
173 request,
174 response: None,
175 batch_id: None,
176 };
177 self.connection.exec_bound(sql!(
178 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
179 cache_row,
180 )
181 }
182
183 async fn generate(
184 &self,
185 model: String,
186 max_tokens: u64,
187 messages: Vec<Message>,
188 ) -> Result<Option<AnthropicResponse>> {
189 let response = self.lookup(&model, max_tokens, &messages)?;
190 if let Some(response) = response {
191 return Ok(Some(response));
192 }
193
194 self.mark_for_batch(&model, max_tokens, &messages)?;
195
196 Ok(None)
197 }
198
199 /// Uploads pending requests as a new batch; downloads finished batches if any.
200 async fn sync_batches(&self) -> Result<()> {
201 self.upload_pending_requests().await?;
202 self.download_finished_batches().await
203 }
204
205 async fn download_finished_batches(&self) -> Result<()> {
206 let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
207 let batch_ids: Vec<String> = self.connection.select(q)?()?;
208
209 for batch_id in batch_ids {
210 let batch_status = anthropic::batches::retrieve_batch(
211 self.http_client.as_ref(),
212 ANTHROPIC_API_URL,
213 &self.api_key,
214 &batch_id,
215 )
216 .await
217 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
218
219 log::info!(
220 "Batch {} status: {}",
221 batch_id,
222 batch_status.processing_status
223 );
224
225 if batch_status.processing_status == "ended" {
226 let results = anthropic::batches::retrieve_batch_results(
227 self.http_client.as_ref(),
228 ANTHROPIC_API_URL,
229 &self.api_key,
230 &batch_id,
231 )
232 .await
233 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
234
235 let mut success_count = 0;
236 for result in results {
237 let request_hash = result
238 .custom_id
239 .strip_prefix("req_hash_")
240 .unwrap_or(&result.custom_id)
241 .to_string();
242
243 match result.result {
244 anthropic::batches::BatchResult::Succeeded { message } => {
245 let response_json = serde_json::to_string(&message)?;
246 let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
247 self.connection.exec_bound(q)?((response_json, request_hash))?;
248 success_count += 1;
249 }
250 anthropic::batches::BatchResult::Errored { error } => {
251 log::error!("Batch request {} failed: {:?}", request_hash, error);
252 }
253 anthropic::batches::BatchResult::Canceled => {
254 log::warn!("Batch request {} was canceled", request_hash);
255 }
256 anthropic::batches::BatchResult::Expired => {
257 log::warn!("Batch request {} expired", request_hash);
258 }
259 }
260 }
261 log::info!("Uploaded {} successful requests", success_count);
262 }
263 }
264
265 Ok(())
266 }
267
268 async fn upload_pending_requests(&self) -> Result<String> {
269 let q = sql!(
270 SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
271 );
272
273 let rows: Vec<(String, String)> = self.connection.select(q)?()?;
274
275 if rows.is_empty() {
276 return Ok(String::new());
277 }
278
279 let batch_requests = rows
280 .iter()
281 .map(|(hash, request_str)| {
282 let serializable_request: SerializableRequest =
283 serde_json::from_str(&request_str).unwrap();
284
285 let messages: Vec<Message> = serializable_request
286 .messages
287 .into_iter()
288 .map(|msg| Message {
289 role: match msg.role.as_str() {
290 "user" => Role::User,
291 "assistant" => Role::Assistant,
292 _ => Role::User,
293 },
294 content: vec![RequestContent::Text {
295 text: msg.content,
296 cache_control: None,
297 }],
298 })
299 .collect();
300
301 let params = AnthropicRequest {
302 model: serializable_request.model,
303 max_tokens: serializable_request.max_tokens,
304 messages,
305 tools: Vec::new(),
306 thinking: None,
307 tool_choice: None,
308 system: None,
309 metadata: None,
310 stop_sequences: Vec::new(),
311 temperature: None,
312 top_k: None,
313 top_p: None,
314 };
315
316 let custom_id = format!("req_hash_{}", hash);
317 anthropic::batches::BatchRequest { custom_id, params }
318 })
319 .collect::<Vec<_>>();
320
321 let batch_len = batch_requests.len();
322 let batch = anthropic::batches::create_batch(
323 self.http_client.as_ref(),
324 ANTHROPIC_API_URL,
325 &self.api_key,
326 anthropic::batches::CreateBatchRequest {
327 requests: batch_requests,
328 },
329 )
330 .await
331 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
332
333 let q = sql!(
334 UPDATE cache SET batch_id = ? WHERE batch_id is NULL
335 );
336 self.connection.exec_bound(q)?(batch.id.as_str())?;
337
338 log::info!("Uploaded batch with {} requests", batch_len);
339
340 Ok(batch.id)
341 }
342
343 fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
344 let mut hasher = std::hash::DefaultHasher::new();
345 model.hash(&mut hasher);
346 max_tokens.hash(&mut hasher);
347 for msg in messages {
348 message_content_to_string(&msg.content).hash(&mut hasher);
349 }
350 let request_hash = hasher.finish();
351 format!("{request_hash:016x}")
352 }
353}
354
355fn message_content_to_string(content: &[RequestContent]) -> String {
356 content
357 .iter()
358 .filter_map(|c| match c {
359 RequestContent::Text { text, .. } => Some(text.clone()),
360 _ => None,
361 })
362 .collect::<Vec<String>>()
363 .join("\n")
364}
365
/// Front-end over the available client implementations.
pub enum LlmClient {
    // No batching
    Plain(PlainLlmClient),
    /// Queues requests into the Anthropic batch API via a local SQLite cache.
    Batch(BatchingLlmClient),
    /// Placeholder that panics if `generate`/`sync_batches` is called.
    Dummy,
}
372
373impl LlmClient {
374 pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
375 Ok(Self::Plain(PlainLlmClient::new(http_client)?))
376 }
377
378 pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
379 Ok(Self::Batch(BatchingLlmClient::new(
380 cache_path,
381 http_client,
382 )?))
383 }
384
385 #[allow(dead_code)]
386 pub fn dummy() -> Self {
387 Self::Dummy
388 }
389
390 pub async fn generate(
391 &self,
392 model: String,
393 max_tokens: u64,
394 messages: Vec<Message>,
395 ) -> Result<Option<AnthropicResponse>> {
396 match self {
397 LlmClient::Plain(plain_llm_client) => plain_llm_client
398 .generate(model, max_tokens, messages)
399 .await
400 .map(Some),
401 LlmClient::Batch(batching_llm_client) => {
402 batching_llm_client
403 .generate(model, max_tokens, messages)
404 .await
405 }
406 LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
407 }
408 }
409
410 pub async fn sync_batches(&self) -> Result<()> {
411 match self {
412 LlmClient::Plain(_) => Ok(()),
413 LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
414 LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
415 }
416 }
417}