// anthropic_client.rs

use anthropic::{
    ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
    Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
    stream_completion,
};
use anyhow::Result;
use futures::StreamExt as _;
use http_client::HttpClient;
use indoc::indoc;
use reqwest_client::ReqwestClient;
use sqlez::bindable::{Bind, StaticColumnCount};
use sqlez_macros::sql;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::Arc;
 18
 19pub struct PlainLlmClient {
 20    pub http_client: Arc<dyn HttpClient>,
 21    pub api_key: String,
 22}
 23
 24impl PlainLlmClient {
 25    pub fn new() -> Result<Self> {
 26        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 27        let api_key = std::env::var("ANTHROPIC_API_KEY")
 28            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
 29        Ok(Self {
 30            http_client,
 31            api_key,
 32        })
 33    }
 34
 35    pub async fn generate(
 36        &self,
 37        model: &str,
 38        max_tokens: u64,
 39        messages: Vec<Message>,
 40    ) -> Result<AnthropicResponse> {
 41        let request = AnthropicRequest {
 42            model: model.to_string(),
 43            max_tokens,
 44            messages,
 45            tools: Vec::new(),
 46            thinking: None,
 47            tool_choice: None,
 48            system: None,
 49            metadata: None,
 50            stop_sequences: Vec::new(),
 51            temperature: None,
 52            top_k: None,
 53            top_p: None,
 54        };
 55
 56        let response = non_streaming_completion(
 57            self.http_client.as_ref(),
 58            ANTHROPIC_API_URL,
 59            &self.api_key,
 60            request,
 61            None,
 62        )
 63        .await
 64        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 65
 66        Ok(response)
 67    }
 68
 69    pub async fn generate_streaming<F>(
 70        &self,
 71        model: &str,
 72        max_tokens: u64,
 73        messages: Vec<Message>,
 74        mut on_progress: F,
 75    ) -> Result<AnthropicResponse>
 76    where
 77        F: FnMut(usize, &str),
 78    {
 79        let request = AnthropicRequest {
 80            model: model.to_string(),
 81            max_tokens,
 82            messages,
 83            tools: Vec::new(),
 84            thinking: None,
 85            tool_choice: None,
 86            system: None,
 87            metadata: None,
 88            stop_sequences: Vec::new(),
 89            temperature: None,
 90            top_k: None,
 91            top_p: None,
 92        };
 93
 94        let mut stream = stream_completion(
 95            self.http_client.as_ref(),
 96            ANTHROPIC_API_URL,
 97            &self.api_key,
 98            request,
 99            None,
100        )
101        .await
102        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
103
104        let mut response: Option<AnthropicResponse> = None;
105        let mut text_content = String::new();
106
107        while let Some(event_result) = stream.next().await {
108            let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
109
110            match event {
111                Event::MessageStart { message } => {
112                    response = Some(message);
113                }
114                Event::ContentBlockDelta { delta, .. } => {
115                    if let anthropic::ContentDelta::TextDelta { text } = delta {
116                        text_content.push_str(&text);
117                        on_progress(text_content.len(), &text_content);
118                    }
119                }
120                _ => {}
121            }
122        }
123
124        let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
125
126        if response.content.is_empty() && !text_content.is_empty() {
127            response
128                .content
129                .push(ResponseContent::Text { text: text_content });
130        }
131
132        Ok(response)
133    }
134}
135
136pub struct BatchingLlmClient {
137    connection: sqlez::connection::Connection,
138    http_client: Arc<dyn HttpClient>,
139    api_key: String,
140}
141
/// One row of the on-disk `cache` table.
struct CacheRow {
    // Hex-encoded hash of (model, max_tokens, message text); primary key.
    request_hash: String,
    // Serialized request JSON; written when the request is queued.
    request: Option<String>,
    // Serialized response (or error payload) JSON; written on download.
    response: Option<String>,
    // Anthropic batch id; written once the request has been uploaded.
    batch_id: Option<String>,
}
148
149impl StaticColumnCount for CacheRow {
150    fn column_count() -> usize {
151        4
152    }
153}
154
155impl Bind for CacheRow {
156    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
157        let next_index = statement.bind(&self.request_hash, start_index)?;
158        let next_index = statement.bind(&self.request, next_index)?;
159        let next_index = statement.bind(&self.response, next_index)?;
160        let next_index = statement.bind(&self.batch_id, next_index)?;
161        Ok(next_index)
162    }
163}
164
165#[derive(serde::Serialize, serde::Deserialize)]
166struct SerializableRequest {
167    model: String,
168    max_tokens: u64,
169    messages: Vec<SerializableMessage>,
170}
171
172#[derive(serde::Serialize, serde::Deserialize)]
173struct SerializableMessage {
174    role: String,
175    content: String,
176}
177
178impl BatchingLlmClient {
179    fn new(cache_path: &Path) -> Result<Self> {
180        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
181        let api_key = std::env::var("ANTHROPIC_API_KEY")
182            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
183
184        let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
185        let mut statement = sqlez::statement::Statement::prepare(
186            &connection,
187            indoc! {"
188                CREATE TABLE IF NOT EXISTS cache (
189                    request_hash TEXT PRIMARY KEY,
190                    request TEXT,
191                    response TEXT,
192                    batch_id TEXT
193                );
194                "},
195        )?;
196        statement.exec()?;
197        drop(statement);
198
199        Ok(Self {
200            connection,
201            http_client,
202            api_key,
203        })
204    }
205
206    pub fn lookup(
207        &self,
208        model: &str,
209        max_tokens: u64,
210        messages: &[Message],
211    ) -> Result<Option<AnthropicResponse>> {
212        let request_hash_str = Self::request_hash(model, max_tokens, messages);
213        let response: Vec<String> = self.connection.select_bound(
214            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
215        )?(request_hash_str.as_str())?;
216        Ok(response
217            .into_iter()
218            .next()
219            .and_then(|text| serde_json::from_str(&text).ok()))
220    }
221
222    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
223        let request_hash = Self::request_hash(model, max_tokens, messages);
224
225        let serializable_messages: Vec<SerializableMessage> = messages
226            .iter()
227            .map(|msg| SerializableMessage {
228                role: match msg.role {
229                    Role::User => "user".to_string(),
230                    Role::Assistant => "assistant".to_string(),
231                },
232                content: message_content_to_string(&msg.content),
233            })
234            .collect();
235
236        let serializable_request = SerializableRequest {
237            model: model.to_string(),
238            max_tokens,
239            messages: serializable_messages,
240        };
241
242        let request = Some(serde_json::to_string(&serializable_request)?);
243        let cache_row = CacheRow {
244            request_hash,
245            request,
246            response: None,
247            batch_id: None,
248        };
249        self.connection.exec_bound(sql!(
250            INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
251            cache_row,
252        )
253    }
254
255    async fn generate(
256        &self,
257        model: &str,
258        max_tokens: u64,
259        messages: Vec<Message>,
260    ) -> Result<Option<AnthropicResponse>> {
261        let response = self.lookup(model, max_tokens, &messages)?;
262        if let Some(response) = response {
263            return Ok(Some(response));
264        }
265
266        self.mark_for_batch(model, max_tokens, &messages)?;
267
268        Ok(None)
269    }
270
271    /// Uploads pending requests as a new batch; downloads finished batches if any.
272    async fn sync_batches(&self) -> Result<()> {
273        self.upload_pending_requests().await?;
274        self.download_finished_batches().await
275    }
276
277    async fn download_finished_batches(&self) -> Result<()> {
278        let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
279        let batch_ids: Vec<String> = self.connection.select(q)?()?;
280
281        for batch_id in batch_ids {
282            let batch_status = anthropic::batches::retrieve_batch(
283                self.http_client.as_ref(),
284                ANTHROPIC_API_URL,
285                &self.api_key,
286                &batch_id,
287            )
288            .await
289            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
290
291            log::info!(
292                "Batch {} status: {}",
293                batch_id,
294                batch_status.processing_status
295            );
296
297            if batch_status.processing_status == "ended" {
298                let results = anthropic::batches::retrieve_batch_results(
299                    self.http_client.as_ref(),
300                    ANTHROPIC_API_URL,
301                    &self.api_key,
302                    &batch_id,
303                )
304                .await
305                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
306
307                let mut updates: Vec<(String, String)> = Vec::new();
308                let mut success_count = 0;
309                for result in results {
310                    let request_hash = result
311                        .custom_id
312                        .strip_prefix("req_hash_")
313                        .unwrap_or(&result.custom_id)
314                        .to_string();
315
316                    match result.result {
317                        anthropic::batches::BatchResult::Succeeded { message } => {
318                            let response_json = serde_json::to_string(&message)?;
319                            updates.push((response_json, request_hash));
320                            success_count += 1;
321                        }
322                        anthropic::batches::BatchResult::Errored { error } => {
323                            log::error!(
324                                "Batch request {} failed: {}: {}",
325                                request_hash,
326                                error.error.error_type,
327                                error.error.message
328                            );
329                            let error_json = serde_json::json!({
330                                "error": {
331                                    "type": error.error.error_type,
332                                    "message": error.error.message
333                                }
334                            })
335                            .to_string();
336                            updates.push((error_json, request_hash));
337                        }
338                        anthropic::batches::BatchResult::Canceled => {
339                            log::warn!("Batch request {} was canceled", request_hash);
340                            let error_json = serde_json::json!({
341                                "error": {
342                                    "type": "canceled",
343                                    "message": "Batch request was canceled"
344                                }
345                            })
346                            .to_string();
347                            updates.push((error_json, request_hash));
348                        }
349                        anthropic::batches::BatchResult::Expired => {
350                            log::warn!("Batch request {} expired", request_hash);
351                            let error_json = serde_json::json!({
352                                "error": {
353                                    "type": "expired",
354                                    "message": "Batch request expired"
355                                }
356                            })
357                            .to_string();
358                            updates.push((error_json, request_hash));
359                        }
360                    }
361                }
362
363                self.connection.with_savepoint("batch_download", || {
364                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
365                    let mut exec = self.connection.exec_bound::<(&str, &str)>(q)?;
366                    for (response_json, request_hash) in &updates {
367                        exec((response_json.as_str(), request_hash.as_str()))?;
368                    }
369                    Ok(())
370                })?;
371                log::info!("Downloaded {} successful requests", success_count);
372            }
373        }
374
375        Ok(())
376    }
377
378    async fn upload_pending_requests(&self) -> Result<String> {
379        let q = sql!(
380        SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
381        );
382
383        let rows: Vec<(String, String)> = self.connection.select(q)?()?;
384
385        if rows.is_empty() {
386            return Ok(String::new());
387        }
388
389        let batch_requests = rows
390            .iter()
391            .map(|(hash, request_str)| {
392                let serializable_request: SerializableRequest =
393                    serde_json::from_str(&request_str).unwrap();
394
395                let messages: Vec<Message> = serializable_request
396                    .messages
397                    .into_iter()
398                    .map(|msg| Message {
399                        role: match msg.role.as_str() {
400                            "user" => Role::User,
401                            "assistant" => Role::Assistant,
402                            _ => Role::User,
403                        },
404                        content: vec![RequestContent::Text {
405                            text: msg.content,
406                            cache_control: None,
407                        }],
408                    })
409                    .collect();
410
411                let params = AnthropicRequest {
412                    model: serializable_request.model,
413                    max_tokens: serializable_request.max_tokens,
414                    messages,
415                    tools: Vec::new(),
416                    thinking: None,
417                    tool_choice: None,
418                    system: None,
419                    metadata: None,
420                    stop_sequences: Vec::new(),
421                    temperature: None,
422                    top_k: None,
423                    top_p: None,
424                };
425
426                let custom_id = format!("req_hash_{}", hash);
427                anthropic::batches::BatchRequest { custom_id, params }
428            })
429            .collect::<Vec<_>>();
430
431        let batch_len = batch_requests.len();
432        let batch = anthropic::batches::create_batch(
433            self.http_client.as_ref(),
434            ANTHROPIC_API_URL,
435            &self.api_key,
436            anthropic::batches::CreateBatchRequest {
437                requests: batch_requests,
438            },
439        )
440        .await
441        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
442
443        let q = sql!(
444            UPDATE cache SET batch_id = ? WHERE batch_id is NULL
445        );
446        self.connection.exec_bound(q)?(batch.id.as_str())?;
447
448        log::info!("Uploaded batch with {} requests", batch_len);
449
450        Ok(batch.id)
451    }
452
453    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
454        let mut hasher = std::hash::DefaultHasher::new();
455        model.hash(&mut hasher);
456        max_tokens.hash(&mut hasher);
457        for msg in messages {
458            message_content_to_string(&msg.content).hash(&mut hasher);
459        }
460        let request_hash = hasher.finish();
461        format!("{request_hash:016x}")
462    }
463}
464
465fn message_content_to_string(content: &[RequestContent]) -> String {
466    content
467        .iter()
468        .filter_map(|c| match c {
469            RequestContent::Text { text, .. } => Some(text.clone()),
470            _ => None,
471        })
472        .collect::<Vec<String>>()
473        .join("\n")
474}
475
476pub enum AnthropicClient {
477    // No batching
478    Plain(PlainLlmClient),
479    Batch(BatchingLlmClient),
480    Dummy,
481}
482
483impl AnthropicClient {
484    pub fn plain() -> Result<Self> {
485        Ok(Self::Plain(PlainLlmClient::new()?))
486    }
487
488    pub fn batch(cache_path: &Path) -> Result<Self> {
489        Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
490    }
491
492    #[allow(dead_code)]
493    pub fn dummy() -> Self {
494        Self::Dummy
495    }
496
497    pub async fn generate(
498        &self,
499        model: &str,
500        max_tokens: u64,
501        messages: Vec<Message>,
502    ) -> Result<Option<AnthropicResponse>> {
503        match self {
504            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
505                .generate(model, max_tokens, messages)
506                .await
507                .map(Some),
508            AnthropicClient::Batch(batching_llm_client) => {
509                batching_llm_client
510                    .generate(model, max_tokens, messages)
511                    .await
512            }
513            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
514        }
515    }
516
517    #[allow(dead_code)]
518    pub async fn generate_streaming<F>(
519        &self,
520        model: &str,
521        max_tokens: u64,
522        messages: Vec<Message>,
523        on_progress: F,
524    ) -> Result<Option<AnthropicResponse>>
525    where
526        F: FnMut(usize, &str),
527    {
528        match self {
529            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
530                .generate_streaming(model, max_tokens, messages, on_progress)
531                .await
532                .map(Some),
533            AnthropicClient::Batch(_) => {
534                anyhow::bail!("Streaming not supported with batching client")
535            }
536            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
537        }
538    }
539
540    pub async fn sync_batches(&self) -> Result<()> {
541        match self {
542            AnthropicClient::Plain(_) => Ok(()),
543            AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
544            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
545        }
546    }
547}