//! anthropic_client.rs — plain and batch-cached clients for the Anthropic API.

  1use anthropic::{
  2    ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
  3    Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
  4    stream_completion,
  5};
  6use anyhow::Result;
  7use futures::StreamExt as _;
  8use http_client::HttpClient;
  9use indoc::indoc;
 10use reqwest_client::ReqwestClient;
 11use sqlez::bindable::Bind;
 12use sqlez::bindable::StaticColumnCount;
 13use sqlez_macros::sql;
 14use std::hash::Hash;
 15use std::hash::Hasher;
 16use std::path::Path;
 17use std::sync::{Arc, Mutex};
 18
/// A direct, non-batching Anthropic API client: every call performs an
/// immediate HTTP request.
pub struct PlainLlmClient {
    // Shared HTTP transport used for all API calls.
    pub http_client: Arc<dyn HttpClient>,
    // API key read from the ANTHROPIC_API_KEY environment variable.
    pub api_key: String,
}
 23
 24impl PlainLlmClient {
 25    pub fn new() -> Result<Self> {
 26        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 27        let api_key = std::env::var("ANTHROPIC_API_KEY")
 28            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
 29        Ok(Self {
 30            http_client,
 31            api_key,
 32        })
 33    }
 34
 35    pub async fn generate(
 36        &self,
 37        model: &str,
 38        max_tokens: u64,
 39        messages: Vec<Message>,
 40    ) -> Result<AnthropicResponse> {
 41        let request = AnthropicRequest {
 42            model: model.to_string(),
 43            max_tokens,
 44            messages,
 45            tools: Vec::new(),
 46            thinking: None,
 47            tool_choice: None,
 48            system: None,
 49            metadata: None,
 50            stop_sequences: Vec::new(),
 51            temperature: None,
 52            top_k: None,
 53            top_p: None,
 54        };
 55
 56        let response = non_streaming_completion(
 57            self.http_client.as_ref(),
 58            ANTHROPIC_API_URL,
 59            &self.api_key,
 60            request,
 61            None,
 62        )
 63        .await
 64        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 65
 66        Ok(response)
 67    }
 68
 69    pub async fn generate_streaming<F>(
 70        &self,
 71        model: &str,
 72        max_tokens: u64,
 73        messages: Vec<Message>,
 74        mut on_progress: F,
 75    ) -> Result<AnthropicResponse>
 76    where
 77        F: FnMut(usize, &str),
 78    {
 79        let request = AnthropicRequest {
 80            model: model.to_string(),
 81            max_tokens,
 82            messages,
 83            tools: Vec::new(),
 84            thinking: None,
 85            tool_choice: None,
 86            system: None,
 87            metadata: None,
 88            stop_sequences: Vec::new(),
 89            temperature: None,
 90            top_k: None,
 91            top_p: None,
 92        };
 93
 94        let mut stream = stream_completion(
 95            self.http_client.as_ref(),
 96            ANTHROPIC_API_URL,
 97            &self.api_key,
 98            request,
 99            None,
100        )
101        .await
102        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
103
104        let mut response: Option<AnthropicResponse> = None;
105        let mut text_content = String::new();
106
107        while let Some(event_result) = stream.next().await {
108            let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
109
110            match event {
111                Event::MessageStart { message } => {
112                    response = Some(message);
113                }
114                Event::ContentBlockDelta { delta, .. } => {
115                    if let anthropic::ContentDelta::TextDelta { text } = delta {
116                        text_content.push_str(&text);
117                        on_progress(text_content.len(), &text_content);
118                    }
119                }
120                _ => {}
121            }
122        }
123
124        let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
125
126        if response.content.is_empty() && !text_content.is_empty() {
127            response
128                .content
129                .push(ResponseContent::Text { text: text_content });
130        }
131
132        Ok(response)
133    }
134}
135
/// A client that serves responses from a local SQLite cache and accumulates
/// cache misses for upload via the Anthropic batch API.
pub struct BatchingLlmClient {
    // SQLite cache of requests/responses; Mutex serializes access across awaits.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    // API key read from the ANTHROPIC_API_KEY environment variable.
    api_key: String,
}
141
/// One row of the `cache` table: a request keyed by its hash, its serialized
/// request JSON, the response JSON once available, and the batch it was
/// uploaded in (both optional while pending).
struct CacheRow {
    request_hash: String,
    request: Option<String>,
    response: Option<String>,
    batch_id: Option<String>,
}
148
impl StaticColumnCount for CacheRow {
    /// A `CacheRow` binds exactly four columns (see `Bind` below).
    fn column_count() -> usize {
        4
    }
}
154
155impl Bind for CacheRow {
156    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
157        let next_index = statement.bind(&self.request_hash, start_index)?;
158        let next_index = statement.bind(&self.request, next_index)?;
159        let next_index = statement.bind(&self.response, next_index)?;
160        let next_index = statement.bind(&self.batch_id, next_index)?;
161        Ok(next_index)
162    }
163}
164
/// JSON-serializable snapshot of a request, stored in the cache's `request`
/// column so pending requests survive process restarts.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
171
/// JSON-serializable message: role as a string ("user"/"assistant") and the
/// flattened text content.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
177
178impl BatchingLlmClient {
179    fn new(cache_path: &Path) -> Result<Self> {
180        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
181        let api_key = std::env::var("ANTHROPIC_API_KEY")
182            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
183
184        let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
185        let mut statement = sqlez::statement::Statement::prepare(
186            &connection,
187            indoc! {"
188                CREATE TABLE IF NOT EXISTS cache (
189                    request_hash TEXT PRIMARY KEY,
190                    request TEXT,
191                    response TEXT,
192                    batch_id TEXT
193                );
194                "},
195        )?;
196        statement.exec()?;
197        drop(statement);
198
199        Ok(Self {
200            connection: Mutex::new(connection),
201            http_client,
202            api_key,
203        })
204    }
205
206    pub fn lookup(
207        &self,
208        model: &str,
209        max_tokens: u64,
210        messages: &[Message],
211        seed: Option<usize>,
212    ) -> Result<Option<AnthropicResponse>> {
213        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
214        let connection = self.connection.lock().unwrap();
215        let response: Vec<String> = connection.select_bound(
216            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
217        )?(request_hash_str.as_str())?;
218        Ok(response
219            .into_iter()
220            .next()
221            .and_then(|text| serde_json::from_str(&text).ok()))
222    }
223
224    pub fn mark_for_batch(
225        &self,
226        model: &str,
227        max_tokens: u64,
228        messages: &[Message],
229        seed: Option<usize>,
230    ) -> Result<()> {
231        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
232
233        let serializable_messages: Vec<SerializableMessage> = messages
234            .iter()
235            .map(|msg| SerializableMessage {
236                role: match msg.role {
237                    Role::User => "user".to_string(),
238                    Role::Assistant => "assistant".to_string(),
239                },
240                content: message_content_to_string(&msg.content),
241            })
242            .collect();
243
244        let serializable_request = SerializableRequest {
245            model: model.to_string(),
246            max_tokens,
247            messages: serializable_messages,
248        };
249
250        let request = Some(serde_json::to_string(&serializable_request)?);
251        let cache_row = CacheRow {
252            request_hash,
253            request,
254            response: None,
255            batch_id: None,
256        };
257        let connection = self.connection.lock().unwrap();
258        connection.exec_bound::<CacheRow>(sql!(
259            INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
260            cache_row,
261        )
262    }
263
264    async fn generate(
265        &self,
266        model: &str,
267        max_tokens: u64,
268        messages: Vec<Message>,
269        seed: Option<usize>,
270        cache_only: bool,
271    ) -> Result<Option<AnthropicResponse>> {
272        let response = self.lookup(model, max_tokens, &messages, seed)?;
273        if let Some(response) = response {
274            return Ok(Some(response));
275        }
276
277        if !cache_only {
278            self.mark_for_batch(model, max_tokens, &messages, seed)?;
279        }
280
281        Ok(None)
282    }
283
284    /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
285    async fn sync_batches(&self) -> Result<()> {
286        let _batch_ids = self.upload_pending_requests().await?;
287        self.download_finished_batches().await
288    }
289
290    /// Import batch results from external batch IDs (useful for recovering after database loss)
291    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
292        for batch_id in batch_ids {
293            log::info!("Importing batch {}", batch_id);
294
295            let batch_status = anthropic::batches::retrieve_batch(
296                self.http_client.as_ref(),
297                ANTHROPIC_API_URL,
298                &self.api_key,
299                batch_id,
300            )
301            .await
302            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;
303
304            log::info!(
305                "Batch {} status: {}",
306                batch_id,
307                batch_status.processing_status
308            );
309
310            if batch_status.processing_status != "ended" {
311                log::warn!(
312                    "Batch {} is not finished (status: {}), skipping",
313                    batch_id,
314                    batch_status.processing_status
315                );
316                continue;
317            }
318
319            let results = anthropic::batches::retrieve_batch_results(
320                self.http_client.as_ref(),
321                ANTHROPIC_API_URL,
322                &self.api_key,
323                batch_id,
324            )
325            .await
326            .map_err(|e| {
327                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
328            })?;
329
330            let mut updates: Vec<(String, String, String)> = Vec::new();
331            let mut success_count = 0;
332            let mut error_count = 0;
333
334            for result in results {
335                let request_hash = result
336                    .custom_id
337                    .strip_prefix("req_hash_")
338                    .unwrap_or(&result.custom_id)
339                    .to_string();
340
341                match result.result {
342                    anthropic::batches::BatchResult::Succeeded { message } => {
343                        let response_json = serde_json::to_string(&message)?;
344                        updates.push((request_hash, response_json, batch_id.clone()));
345                        success_count += 1;
346                    }
347                    anthropic::batches::BatchResult::Errored { error } => {
348                        log::error!(
349                            "Batch request {} failed: {}: {}",
350                            request_hash,
351                            error.error.error_type,
352                            error.error.message
353                        );
354                        let error_json = serde_json::json!({
355                            "error": {
356                                "type": error.error.error_type,
357                                "message": error.error.message
358                            }
359                        })
360                        .to_string();
361                        updates.push((request_hash, error_json, batch_id.clone()));
362                        error_count += 1;
363                    }
364                    anthropic::batches::BatchResult::Canceled => {
365                        log::warn!("Batch request {} was canceled", request_hash);
366                        error_count += 1;
367                    }
368                    anthropic::batches::BatchResult::Expired => {
369                        log::warn!("Batch request {} expired", request_hash);
370                        error_count += 1;
371                    }
372                }
373            }
374
375            let connection = self.connection.lock().unwrap();
376            connection.with_savepoint("batch_import", || {
377                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
378                let q = sql!(
379                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
380                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
381                );
382                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
383                for (request_hash, response_json, batch_id) in &updates {
384                    exec((
385                        request_hash.as_str(),
386                        request_hash.as_str(),
387                        response_json.as_str(),
388                        batch_id.as_str(),
389                    ))?;
390                }
391                Ok(())
392            })?;
393
394            log::info!(
395                "Imported batch {}: {} successful, {} errors",
396                batch_id,
397                success_count,
398                error_count
399            );
400        }
401
402        Ok(())
403    }
404
    /// Polls every batch that still has rows awaiting a response; for each
    /// batch whose processing status is "ended", writes each result (success
    /// payload or an error placeholder) into the matching cache row.
    ///
    /// # Errors
    /// Fails on HTTP errors, result serialization errors, or database errors.
    async fn download_finished_batches(&self) -> Result<()> {
        // Batch ids that have at least one row still missing a response.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Only finished batches have results to download; unfinished ones
            // are retried on the next sync.
            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                // (response_json, request_hash) pairs to write back in one
                // savepoint after all results are processed.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;
                for result in results {
                    // custom_id was written as "req_hash_<hash>" at upload time.
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        }
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!(
                                "Batch request {} failed: {}: {}",
                                request_hash,
                                error.error.error_type,
                                error.error.message
                            );
                            // Store the error payload so the row no longer
                            // matches `response IS NULL` and is not re-uploaded.
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": error.error.error_type,
                                    "message": error.error.message
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "canceled",
                                    "message": "Batch request was canceled"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "expired",
                                    "message": "Batch request expired"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    }
                }

                // Apply all updates atomically within a savepoint.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
509
510    async fn upload_pending_requests(&self) -> Result<Vec<String>> {
511        const BATCH_CHUNK_SIZE: i32 = 16_000;
512        let mut all_batch_ids = Vec::new();
513        let mut total_uploaded = 0;
514
515        loop {
516            let rows: Vec<(String, String)> = {
517                let connection = self.connection.lock().unwrap();
518                let q = sql!(
519                    SELECT request_hash, request FROM cache
520                    WHERE batch_id IS NULL AND response IS NULL
521                    LIMIT ?
522                );
523                connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
524            };
525
526            if rows.is_empty() {
527                break;
528            }
529
530            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
531
532            let batch_requests = rows
533                .iter()
534                .map(|(hash, request_str)| {
535                    let serializable_request: SerializableRequest =
536                        serde_json::from_str(&request_str).unwrap();
537
538                    let messages: Vec<Message> = serializable_request
539                        .messages
540                        .into_iter()
541                        .map(|msg| Message {
542                            role: match msg.role.as_str() {
543                                "user" => Role::User,
544                                "assistant" => Role::Assistant,
545                                _ => Role::User,
546                            },
547                            content: vec![RequestContent::Text {
548                                text: msg.content,
549                                cache_control: None,
550                            }],
551                        })
552                        .collect();
553
554                    let params = AnthropicRequest {
555                        model: serializable_request.model,
556                        max_tokens: serializable_request.max_tokens,
557                        messages,
558                        tools: Vec::new(),
559                        thinking: None,
560                        tool_choice: None,
561                        system: None,
562                        metadata: None,
563                        stop_sequences: Vec::new(),
564                        temperature: None,
565                        top_k: None,
566                        top_p: None,
567                    };
568
569                    let custom_id = format!("req_hash_{}", hash);
570                    anthropic::batches::BatchRequest { custom_id, params }
571                })
572                .collect::<Vec<_>>();
573
574            let batch_len = batch_requests.len();
575            let batch = anthropic::batches::create_batch(
576                self.http_client.as_ref(),
577                ANTHROPIC_API_URL,
578                &self.api_key,
579                anthropic::batches::CreateBatchRequest {
580                    requests: batch_requests,
581                },
582            )
583            .await
584            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
585
586            {
587                let connection = self.connection.lock().unwrap();
588                connection.with_savepoint("batch_upload", || {
589                    let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
590                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
591                    for hash in &request_hashes {
592                        exec((batch.id.as_str(), hash.as_str()))?;
593                    }
594                    Ok(())
595                })?;
596            }
597
598            total_uploaded += batch_len;
599            log::info!(
600                "Uploaded batch {} with {} requests ({} total)",
601                batch.id,
602                batch_len,
603                total_uploaded
604            );
605
606            all_batch_ids.push(batch.id);
607        }
608
609        if !all_batch_ids.is_empty() {
610            log::info!(
611                "Finished uploading {} batches with {} total requests",
612                all_batch_ids.len(),
613                total_uploaded
614            );
615        }
616
617        Ok(all_batch_ids)
618    }
619
    /// Derives the cache key for a request: a 16-hex-digit hash over the
    /// model, max_tokens, each message's text content, and (when present)
    /// the seed.
    ///
    /// NOTE(review): message roles are not hashed, so user/assistant swaps of
    /// identical text collide — confirm this is acceptable.
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable
    /// across Rust releases, so a toolchain upgrade may invalidate persisted
    /// cache keys — confirm this is acceptable for the cache's lifetime.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            message_content_to_string(&msg.content).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        // Zero-padded so keys sort and compare as fixed-width strings.
        format!("{request_hash:016x}")
    }
638}
639
640fn message_content_to_string(content: &[RequestContent]) -> String {
641    content
642        .iter()
643        .filter_map(|c| match c {
644            RequestContent::Text { text, .. } => Some(text.clone()),
645            _ => None,
646        })
647        .collect::<Vec<String>>()
648        .join("\n")
649}
650
/// The client used by the rest of the application: either a direct client, a
/// batch-cached client, or a placeholder that panics on use.
pub enum AnthropicClient {
    /// Direct request/response client — no batching.
    Plain(PlainLlmClient),
    /// SQLite-cached client that accumulates requests for batch upload.
    Batch(BatchingLlmClient),
    /// Placeholder; every operation panics if invoked.
    Dummy,
}
657
658impl AnthropicClient {
659    pub fn plain() -> Result<Self> {
660        Ok(Self::Plain(PlainLlmClient::new()?))
661    }
662
663    pub fn batch(cache_path: &Path) -> Result<Self> {
664        Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
665    }
666
667    #[allow(dead_code)]
668    pub fn dummy() -> Self {
669        Self::Dummy
670    }
671
672    pub async fn generate(
673        &self,
674        model: &str,
675        max_tokens: u64,
676        messages: Vec<Message>,
677        seed: Option<usize>,
678        cache_only: bool,
679    ) -> Result<Option<AnthropicResponse>> {
680        match self {
681            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
682                .generate(model, max_tokens, messages)
683                .await
684                .map(Some),
685            AnthropicClient::Batch(batching_llm_client) => {
686                batching_llm_client
687                    .generate(model, max_tokens, messages, seed, cache_only)
688                    .await
689            }
690            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
691        }
692    }
693
694    #[allow(dead_code)]
695    pub async fn generate_streaming<F>(
696        &self,
697        model: &str,
698        max_tokens: u64,
699        messages: Vec<Message>,
700        on_progress: F,
701    ) -> Result<Option<AnthropicResponse>>
702    where
703        F: FnMut(usize, &str),
704    {
705        match self {
706            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
707                .generate_streaming(model, max_tokens, messages, on_progress)
708                .await
709                .map(Some),
710            AnthropicClient::Batch(_) => {
711                anyhow::bail!("Streaming not supported with batching client")
712            }
713            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
714        }
715    }
716
717    pub async fn sync_batches(&self) -> Result<()> {
718        match self {
719            AnthropicClient::Plain(_) => Ok(()),
720            AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
721            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
722        }
723    }
724
725    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
726        match self {
727            AnthropicClient::Plain(_) => {
728                anyhow::bail!("Import batches is only supported with batching client")
729            }
730            AnthropicClient::Batch(batching_llm_client) => {
731                batching_llm_client.import_batches(batch_ids).await
732            }
733            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
734        }
735    }
736}