1use anthropic::{
2 ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, Role, non_streaming_completion,
4};
5use anyhow::Result;
6use http_client::HttpClient;
7use indoc::indoc;
8use reqwest_client::ReqwestClient;
9use sqlez::bindable::Bind;
10use sqlez::bindable::StaticColumnCount;
11use sqlez_macros::sql;
12use std::hash::Hash;
13use std::hash::Hasher;
14use std::path::Path;
15use std::sync::Arc;
16
/// Direct (non-batching) Anthropic client: every `generate` call performs an
/// immediate HTTP request and returns the response.
pub struct PlainLlmClient {
    // Shared HTTP transport used for all API calls.
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key, read from the `ANTHROPIC_API_KEY` env var on construction.
    api_key: String,
}
21
22impl PlainLlmClient {
23 fn new() -> Result<Self> {
24 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
25 let api_key = std::env::var("ANTHROPIC_API_KEY")
26 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
27 Ok(Self {
28 http_client,
29 api_key,
30 })
31 }
32
33 async fn generate(
34 &self,
35 model: &str,
36 max_tokens: u64,
37 messages: Vec<Message>,
38 ) -> Result<AnthropicResponse> {
39 let request = AnthropicRequest {
40 model: model.to_string(),
41 max_tokens,
42 messages,
43 tools: Vec::new(),
44 thinking: None,
45 tool_choice: None,
46 system: None,
47 metadata: None,
48 stop_sequences: Vec::new(),
49 temperature: None,
50 top_k: None,
51 top_p: None,
52 };
53
54 let response = non_streaming_completion(
55 self.http_client.as_ref(),
56 ANTHROPIC_API_URL,
57 &self.api_key,
58 request,
59 None,
60 )
61 .await
62 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
63
64 Ok(response)
65 }
66}
67
/// Anthropic client that defers requests to the batch API, using a SQLite
/// file as a persistent request/response cache.
pub struct BatchingLlmClient {
    // Connection to the on-disk SQLite cache (table `cache`).
    connection: sqlez::connection::Connection,
    // Shared HTTP transport used for all API calls.
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key, read from the `ANTHROPIC_API_KEY` env var on construction.
    api_key: String,
}
73
/// One row of the SQLite `cache` table. Field order matches the table's
/// column order, which the `Bind` impl below relies on.
struct CacheRow {
    // Hex-encoded request hash; primary key of the table.
    request_hash: String,
    // JSON-serialized `SerializableRequest`, set when the request is queued.
    request: Option<String>,
    // JSON-serialized Anthropic response, set once the batch result is downloaded.
    response: Option<String>,
    // Id of the uploaded batch containing this request; `None` while pending.
    batch_id: Option<String>,
}
80
impl StaticColumnCount for CacheRow {
    /// Number of columns bound by `CacheRow`'s `Bind` impl
    /// (request_hash, request, response, batch_id).
    fn column_count() -> usize {
        4
    }
}
86
impl Bind for CacheRow {
    // Binds the four columns, in table-declaration order, starting at
    // `start_index`; returns the next free parameter index. The order here
    // must stay in sync with the `INSERT INTO cache(...)` column list.
    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
        let next_index = statement.bind(&self.request_hash, start_index)?;
        let next_index = statement.bind(&self.request, next_index)?;
        let next_index = statement.bind(&self.response, next_index)?;
        let next_index = statement.bind(&self.batch_id, next_index)?;
        Ok(next_index)
    }
}
96
/// JSON-storable subset of an Anthropic request, persisted in the `request`
/// column of the cache and rebuilt into a full request at upload time.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
103
/// Flattened message for cache storage: role as a string ("user"/"assistant")
/// and the message's text parts joined into one string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
109
110impl BatchingLlmClient {
111 fn new(cache_path: &Path) -> Result<Self> {
112 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
113 let api_key = std::env::var("ANTHROPIC_API_KEY")
114 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
115
116 let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
117 let mut statement = sqlez::statement::Statement::prepare(
118 &connection,
119 indoc! {"
120 CREATE TABLE IF NOT EXISTS cache (
121 request_hash TEXT PRIMARY KEY,
122 request TEXT,
123 response TEXT,
124 batch_id TEXT
125 );
126 "},
127 )?;
128 statement.exec()?;
129 drop(statement);
130
131 Ok(Self {
132 connection,
133 http_client,
134 api_key,
135 })
136 }
137
138 pub fn lookup(
139 &self,
140 model: &str,
141 max_tokens: u64,
142 messages: &[Message],
143 ) -> Result<Option<AnthropicResponse>> {
144 let request_hash_str = Self::request_hash(model, max_tokens, messages);
145 let response: Vec<String> = self.connection.select_bound(
146 &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
147 )?(request_hash_str.as_str())?;
148 Ok(response
149 .into_iter()
150 .next()
151 .and_then(|text| serde_json::from_str(&text).ok()))
152 }
153
154 pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
155 let request_hash = Self::request_hash(model, max_tokens, messages);
156
157 let serializable_messages: Vec<SerializableMessage> = messages
158 .iter()
159 .map(|msg| SerializableMessage {
160 role: match msg.role {
161 Role::User => "user".to_string(),
162 Role::Assistant => "assistant".to_string(),
163 },
164 content: message_content_to_string(&msg.content),
165 })
166 .collect();
167
168 let serializable_request = SerializableRequest {
169 model: model.to_string(),
170 max_tokens,
171 messages: serializable_messages,
172 };
173
174 let request = Some(serde_json::to_string(&serializable_request)?);
175 let cache_row = CacheRow {
176 request_hash,
177 request,
178 response: None,
179 batch_id: None,
180 };
181 self.connection.exec_bound(sql!(
182 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
183 cache_row,
184 )
185 }
186
187 async fn generate(
188 &self,
189 model: &str,
190 max_tokens: u64,
191 messages: Vec<Message>,
192 ) -> Result<Option<AnthropicResponse>> {
193 let response = self.lookup(model, max_tokens, &messages)?;
194 if let Some(response) = response {
195 return Ok(Some(response));
196 }
197
198 self.mark_for_batch(model, max_tokens, &messages)?;
199
200 Ok(None)
201 }
202
203 /// Uploads pending requests as a new batch; downloads finished batches if any.
204 async fn sync_batches(&self) -> Result<()> {
205 self.upload_pending_requests().await?;
206 self.download_finished_batches().await
207 }
208
209 async fn download_finished_batches(&self) -> Result<()> {
210 let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
211 let batch_ids: Vec<String> = self.connection.select(q)?()?;
212
213 for batch_id in batch_ids {
214 let batch_status = anthropic::batches::retrieve_batch(
215 self.http_client.as_ref(),
216 ANTHROPIC_API_URL,
217 &self.api_key,
218 &batch_id,
219 )
220 .await
221 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
222
223 log::info!(
224 "Batch {} status: {}",
225 batch_id,
226 batch_status.processing_status
227 );
228
229 if batch_status.processing_status == "ended" {
230 let results = anthropic::batches::retrieve_batch_results(
231 self.http_client.as_ref(),
232 ANTHROPIC_API_URL,
233 &self.api_key,
234 &batch_id,
235 )
236 .await
237 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
238
239 let mut success_count = 0;
240 for result in results {
241 let request_hash = result
242 .custom_id
243 .strip_prefix("req_hash_")
244 .unwrap_or(&result.custom_id)
245 .to_string();
246
247 match result.result {
248 anthropic::batches::BatchResult::Succeeded { message } => {
249 let response_json = serde_json::to_string(&message)?;
250 let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
251 self.connection.exec_bound(q)?((response_json, request_hash))?;
252 success_count += 1;
253 }
254 anthropic::batches::BatchResult::Errored { error } => {
255 log::error!("Batch request {} failed: {:?}", request_hash, error);
256 }
257 anthropic::batches::BatchResult::Canceled => {
258 log::warn!("Batch request {} was canceled", request_hash);
259 }
260 anthropic::batches::BatchResult::Expired => {
261 log::warn!("Batch request {} expired", request_hash);
262 }
263 }
264 }
265 log::info!("Downloaded {} successful requests", success_count);
266 }
267 }
268
269 Ok(())
270 }
271
272 async fn upload_pending_requests(&self) -> Result<String> {
273 let q = sql!(
274 SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
275 );
276
277 let rows: Vec<(String, String)> = self.connection.select(q)?()?;
278
279 if rows.is_empty() {
280 return Ok(String::new());
281 }
282
283 let batch_requests = rows
284 .iter()
285 .map(|(hash, request_str)| {
286 let serializable_request: SerializableRequest =
287 serde_json::from_str(&request_str).unwrap();
288
289 let messages: Vec<Message> = serializable_request
290 .messages
291 .into_iter()
292 .map(|msg| Message {
293 role: match msg.role.as_str() {
294 "user" => Role::User,
295 "assistant" => Role::Assistant,
296 _ => Role::User,
297 },
298 content: vec![RequestContent::Text {
299 text: msg.content,
300 cache_control: None,
301 }],
302 })
303 .collect();
304
305 let params = AnthropicRequest {
306 model: serializable_request.model,
307 max_tokens: serializable_request.max_tokens,
308 messages,
309 tools: Vec::new(),
310 thinking: None,
311 tool_choice: None,
312 system: None,
313 metadata: None,
314 stop_sequences: Vec::new(),
315 temperature: None,
316 top_k: None,
317 top_p: None,
318 };
319
320 let custom_id = format!("req_hash_{}", hash);
321 anthropic::batches::BatchRequest { custom_id, params }
322 })
323 .collect::<Vec<_>>();
324
325 let batch_len = batch_requests.len();
326 let batch = anthropic::batches::create_batch(
327 self.http_client.as_ref(),
328 ANTHROPIC_API_URL,
329 &self.api_key,
330 anthropic::batches::CreateBatchRequest {
331 requests: batch_requests,
332 },
333 )
334 .await
335 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
336
337 let q = sql!(
338 UPDATE cache SET batch_id = ? WHERE batch_id is NULL
339 );
340 self.connection.exec_bound(q)?(batch.id.as_str())?;
341
342 log::info!("Uploaded batch with {} requests", batch_len);
343
344 Ok(batch.id)
345 }
346
347 fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
348 let mut hasher = std::hash::DefaultHasher::new();
349 model.hash(&mut hasher);
350 max_tokens.hash(&mut hasher);
351 for msg in messages {
352 message_content_to_string(&msg.content).hash(&mut hasher);
353 }
354 let request_hash = hasher.finish();
355 format!("{request_hash:016x}")
356 }
357}
358
359fn message_content_to_string(content: &[RequestContent]) -> String {
360 content
361 .iter()
362 .filter_map(|c| match c {
363 RequestContent::Text { text, .. } => Some(text.clone()),
364 _ => None,
365 })
366 .collect::<Vec<String>>()
367 .join("\n")
368}
369
/// Front-end enum selecting one of the client strategies.
pub enum AnthropicClient {
    /// Sends each request immediately; no batching.
    Plain(PlainLlmClient),
    /// Defers requests to the batch API with a SQLite-backed cache.
    Batch(BatchingLlmClient),
    /// Placeholder that panics if used; for code paths that never call the LLM.
    Dummy,
}
376
377impl AnthropicClient {
378 pub fn plain() -> Result<Self> {
379 Ok(Self::Plain(PlainLlmClient::new()?))
380 }
381
382 pub fn batch(cache_path: &Path) -> Result<Self> {
383 Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
384 }
385
386 #[allow(dead_code)]
387 pub fn dummy() -> Self {
388 Self::Dummy
389 }
390
391 pub async fn generate(
392 &self,
393 model: &str,
394 max_tokens: u64,
395 messages: Vec<Message>,
396 ) -> Result<Option<AnthropicResponse>> {
397 match self {
398 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
399 .generate(model, max_tokens, messages)
400 .await
401 .map(Some),
402 AnthropicClient::Batch(batching_llm_client) => {
403 batching_llm_client
404 .generate(model, max_tokens, messages)
405 .await
406 }
407 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
408 }
409 }
410
411 pub async fn sync_batches(&self) -> Result<()> {
412 match self {
413 AnthropicClient::Plain(_) => Ok(()),
414 AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
415 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
416 }
417 }
418}