1use anthropic::{
2 ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
3 Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
4 stream_completion,
5};
6use anyhow::Result;
7use futures::StreamExt as _;
8use http_client::HttpClient;
9use indoc::indoc;
10use reqwest_client::ReqwestClient;
11use sqlez::bindable::Bind;
12use sqlez::bindable::StaticColumnCount;
13use sqlez_macros::sql;
14use std::hash::Hash;
15use std::hash::Hasher;
16use std::path::Path;
17use std::sync::{Arc, Mutex};
18
/// A direct, non-batching Anthropic API client: each `generate` call issues
/// one HTTP request immediately.
pub struct PlainLlmClient {
    // Shared HTTP transport used for every API call.
    pub http_client: Arc<dyn HttpClient>,
    // Anthropic API key; read from the environment in `new`.
    pub api_key: String,
}
23
24impl PlainLlmClient {
25 pub fn new() -> Result<Self> {
26 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
27 let api_key = std::env::var("ANTHROPIC_API_KEY")
28 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
29 Ok(Self {
30 http_client,
31 api_key,
32 })
33 }
34
35 pub async fn generate(
36 &self,
37 model: &str,
38 max_tokens: u64,
39 messages: Vec<Message>,
40 ) -> Result<AnthropicResponse> {
41 let request = AnthropicRequest {
42 model: model.to_string(),
43 max_tokens,
44 messages,
45 tools: Vec::new(),
46 thinking: None,
47 tool_choice: None,
48 system: None,
49 metadata: None,
50 output_config: None,
51 stop_sequences: Vec::new(),
52 temperature: None,
53 top_k: None,
54 top_p: None,
55 };
56
57 let response = non_streaming_completion(
58 self.http_client.as_ref(),
59 ANTHROPIC_API_URL,
60 &self.api_key,
61 request,
62 None,
63 )
64 .await
65 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
66
67 Ok(response)
68 }
69
70 pub async fn generate_streaming<F>(
71 &self,
72 model: &str,
73 max_tokens: u64,
74 messages: Vec<Message>,
75 mut on_progress: F,
76 ) -> Result<AnthropicResponse>
77 where
78 F: FnMut(usize, &str),
79 {
80 let request = AnthropicRequest {
81 model: model.to_string(),
82 max_tokens,
83 messages,
84 tools: Vec::new(),
85 thinking: None,
86 tool_choice: None,
87 system: None,
88 metadata: None,
89 output_config: None,
90 stop_sequences: Vec::new(),
91 temperature: None,
92 top_k: None,
93 top_p: None,
94 };
95
96 let mut stream = stream_completion(
97 self.http_client.as_ref(),
98 ANTHROPIC_API_URL,
99 &self.api_key,
100 request,
101 None,
102 )
103 .await
104 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
105
106 let mut response: Option<AnthropicResponse> = None;
107 let mut text_content = String::new();
108
109 while let Some(event_result) = stream.next().await {
110 let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
111
112 match event {
113 Event::MessageStart { message } => {
114 response = Some(message);
115 }
116 Event::ContentBlockDelta { delta, .. } => {
117 if let anthropic::ContentDelta::TextDelta { text } = delta {
118 text_content.push_str(&text);
119 on_progress(text_content.len(), &text_content);
120 }
121 }
122 _ => {}
123 }
124 }
125
126 let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
127
128 if response.content.is_empty() && !text_content.is_empty() {
129 response
130 .content
131 .push(ResponseContent::Text { text: text_content });
132 }
133
134 Ok(response)
135 }
136}
137
/// An Anthropic client that never calls the completion API synchronously:
/// requests are recorded in a SQLite cache, uploaded through the Batches
/// API, and served from the cache once batch results are downloaded.
pub struct BatchingLlmClient {
    // SQLite request/response cache; the Mutex serializes all DB access.
    connection: Mutex<sqlez::connection::Connection>,
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key; read from the environment in `new`.
    api_key: String,
}
143
/// One row of the `cache` table (schema created in `BatchingLlmClient::new`).
struct CacheRow {
    // Hex-encoded request hash; the table's primary key.
    request_hash: String,
    // Serialized `SerializableRequest` JSON, when known.
    request: Option<String>,
    // Serialized response (or error payload) JSON, once downloaded.
    response: Option<String>,
    // Id of the uploaded batch containing this request, once uploaded.
    batch_id: Option<String>,
}
150
impl StaticColumnCount for CacheRow {
    // `CacheRow` binds exactly four columns:
    // request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
156
157impl Bind for CacheRow {
158 fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
159 let next_index = statement.bind(&self.request_hash, start_index)?;
160 let next_index = statement.bind(&self.request, next_index)?;
161 let next_index = statement.bind(&self.response, next_index)?;
162 let next_index = statement.bind(&self.batch_id, next_index)?;
163 Ok(next_index)
164 }
165}
166
/// JSON-serializable snapshot of a request, stored in the cache's `request`
/// column so full API requests can be reconstructed at upload time.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
173
/// Flattened message: role as a string ("user" / "assistant") and all text
/// content joined into a single string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
179
180impl BatchingLlmClient {
181 fn new(cache_path: &Path) -> Result<Self> {
182 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
183 let api_key = std::env::var("ANTHROPIC_API_KEY")
184 .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
185
186 let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
187 let mut statement = sqlez::statement::Statement::prepare(
188 &connection,
189 indoc! {"
190 CREATE TABLE IF NOT EXISTS cache (
191 request_hash TEXT PRIMARY KEY,
192 request TEXT,
193 response TEXT,
194 batch_id TEXT
195 );
196 "},
197 )?;
198 statement.exec()?;
199 drop(statement);
200
201 Ok(Self {
202 connection: Mutex::new(connection),
203 http_client,
204 api_key,
205 })
206 }
207
208 pub fn lookup(
209 &self,
210 model: &str,
211 max_tokens: u64,
212 messages: &[Message],
213 seed: Option<usize>,
214 ) -> Result<Option<AnthropicResponse>> {
215 let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
216 let connection = self.connection.lock().unwrap();
217 let response: Vec<String> = connection.select_bound(
218 &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
219 )?(request_hash_str.as_str())?;
220 Ok(response
221 .into_iter()
222 .next()
223 .and_then(|text| serde_json::from_str(&text).ok()))
224 }
225
226 pub fn mark_for_batch(
227 &self,
228 model: &str,
229 max_tokens: u64,
230 messages: &[Message],
231 seed: Option<usize>,
232 ) -> Result<()> {
233 let request_hash = Self::request_hash(model, max_tokens, messages, seed);
234
235 let serializable_messages: Vec<SerializableMessage> = messages
236 .iter()
237 .map(|msg| SerializableMessage {
238 role: match msg.role {
239 Role::User => "user".to_string(),
240 Role::Assistant => "assistant".to_string(),
241 },
242 content: message_content_to_string(&msg.content),
243 })
244 .collect();
245
246 let serializable_request = SerializableRequest {
247 model: model.to_string(),
248 max_tokens,
249 messages: serializable_messages,
250 };
251
252 let request = Some(serde_json::to_string(&serializable_request)?);
253 let cache_row = CacheRow {
254 request_hash,
255 request,
256 response: None,
257 batch_id: None,
258 };
259 let connection = self.connection.lock().unwrap();
260 connection.exec_bound::<CacheRow>(sql!(
261 INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
262 cache_row,
263 )
264 }
265
266 async fn generate(
267 &self,
268 model: &str,
269 max_tokens: u64,
270 messages: Vec<Message>,
271 seed: Option<usize>,
272 cache_only: bool,
273 ) -> Result<Option<AnthropicResponse>> {
274 let response = self.lookup(model, max_tokens, &messages, seed)?;
275 if let Some(response) = response {
276 return Ok(Some(response));
277 }
278
279 if !cache_only {
280 self.mark_for_batch(model, max_tokens, &messages, seed)?;
281 }
282
283 Ok(None)
284 }
285
286 /// Uploads pending requests as batches (chunked to 16k each); downloads finished batches if any.
287 async fn sync_batches(&self) -> Result<()> {
288 let _batch_ids = self.upload_pending_requests().await?;
289 self.download_finished_batches().await
290 }
291
    /// Import batch results from external batch IDs (useful for recovering after database loss)
    pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
        for batch_id in batch_ids {
            log::info!("Importing batch {}", batch_id);

            // Check processing status before fetching results.
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("Failed to retrieve batch {}: {:?}", batch_id, e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Only batches whose processing has "ended" can be imported.
            if batch_status.processing_status != "ended" {
                log::warn!(
                    "Batch {} is not finished (status: {}), skipping",
                    batch_id,
                    batch_status.processing_status
                );
                continue;
            }

            let results = anthropic::batches::retrieve_batch_results(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                batch_id,
            )
            .await
            .map_err(|e| {
                anyhow::anyhow!("Failed to retrieve batch results for {}: {:?}", batch_id, e)
            })?;

            // (request_hash, response JSON, batch id) triples to write back.
            let mut updates: Vec<(String, String, String)> = Vec::new();
            let mut success_count = 0;
            let mut error_count = 0;

            for result in results {
                // Requests are uploaded with custom_id "req_hash_<hash>";
                // fall back to the raw custom_id for externally-created batches.
                let request_hash = result
                    .custom_id
                    .strip_prefix("req_hash_")
                    .unwrap_or(&result.custom_id)
                    .to_string();

                match result.result {
                    anthropic::batches::BatchResult::Succeeded { message } => {
                        let response_json = serde_json::to_string(&message)?;
                        updates.push((request_hash, response_json, batch_id.clone()));
                        success_count += 1;
                    }
                    anthropic::batches::BatchResult::Errored { error } => {
                        log::error!(
                            "Batch request {} failed: {}: {}",
                            request_hash,
                            error.error.error_type,
                            error.error.message
                        );
                        // Persist the error payload so the failure is visible
                        // in the cache.
                        let error_json = serde_json::json!({
                            "error": {
                                "type": error.error.error_type,
                                "message": error.error.message
                            }
                        })
                        .to_string();
                        updates.push((request_hash, error_json, batch_id.clone()));
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Canceled => {
                        // NOTE(review): unlike download_finished_batches, a
                        // canceled result is only counted here, not written to
                        // the cache — confirm the asymmetry is intentional.
                        log::warn!("Batch request {} was canceled", request_hash);
                        error_count += 1;
                    }
                    anthropic::batches::BatchResult::Expired => {
                        // NOTE(review): expired results are likewise counted
                        // but not persisted.
                        log::warn!("Batch request {} expired", request_hash);
                        error_count += 1;
                    }
                }
            }

            // Apply all updates for this batch atomically in one savepoint.
            let connection = self.connection.lock().unwrap();
            connection.with_savepoint("batch_import", || {
                // Use INSERT OR REPLACE to handle both new entries and updating existing ones
                // The inner SELECT preserves any request JSON already stored
                // under this hash (NULL when the hash was not seen before).
                let q = sql!(
                    INSERT OR REPLACE INTO cache(request_hash, request, response, batch_id)
                    VALUES (?, (SELECT request FROM cache WHERE request_hash = ?), ?, ?)
                );
                let mut exec = connection.exec_bound::<(&str, &str, &str, &str)>(q)?;
                for (request_hash, response_json, batch_id) in &updates {
                    exec((
                        request_hash.as_str(),
                        request_hash.as_str(),
                        response_json.as_str(),
                        batch_id.as_str(),
                    ))?;
                }
                Ok(())
            })?;

            log::info!(
                "Imported batch {}: {} successful, {} errors",
                batch_id,
                success_count,
                error_count
            );
        }

        Ok(())
    }
406
    /// Polls every batch that still has rows awaiting a response; for each
    /// batch whose processing has ended, downloads its results and stores the
    /// outcome (success or error payload) in the cache.
    async fn download_finished_batches(&self) -> Result<()> {
        // Collect candidate batch ids up front so the DB lock is not held
        // across the network calls below.
        let batch_ids: Vec<String> = {
            let connection = self.connection.lock().unwrap();
            let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
            connection.select(q)?()?
        };

        for batch_id in &batch_ids {
            let batch_status = anthropic::batches::retrieve_batch(
                self.http_client.as_ref(),
                ANTHROPIC_API_URL,
                &self.api_key,
                &batch_id,
            )
            .await
            .map_err(|e| anyhow::anyhow!("{:?}", e))?;

            log::info!(
                "Batch {} status: {}",
                batch_id,
                batch_status.processing_status
            );

            // Results are only retrievable once processing has "ended";
            // unfinished batches are retried on the next sync.
            if batch_status.processing_status == "ended" {
                let results = anthropic::batches::retrieve_batch_results(
                    self.http_client.as_ref(),
                    ANTHROPIC_API_URL,
                    &self.api_key,
                    &batch_id,
                )
                .await
                .map_err(|e| anyhow::anyhow!("{:?}", e))?;

                // (response JSON, request_hash) pairs to write back.
                let mut updates: Vec<(String, String)> = Vec::new();
                let mut success_count = 0;
                for result in results {
                    // custom_id was uploaded as "req_hash_<hash>".
                    let request_hash = result
                        .custom_id
                        .strip_prefix("req_hash_")
                        .unwrap_or(&result.custom_id)
                        .to_string();

                    // Non-success outcomes are stored as {"error": ...}
                    // payloads so the row is not re-uploaded; presumably
                    // `lookup` fails to parse them as responses and treats
                    // them as misses — confirm against AnthropicResponse's
                    // serde impl.
                    match result.result {
                        anthropic::batches::BatchResult::Succeeded { message } => {
                            let response_json = serde_json::to_string(&message)?;
                            updates.push((response_json, request_hash));
                            success_count += 1;
                        }
                        anthropic::batches::BatchResult::Errored { error } => {
                            log::error!(
                                "Batch request {} failed: {}: {}",
                                request_hash,
                                error.error.error_type,
                                error.error.message
                            );
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": error.error.error_type,
                                    "message": error.error.message
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Canceled => {
                            log::warn!("Batch request {} was canceled", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "canceled",
                                    "message": "Batch request was canceled"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                        anthropic::batches::BatchResult::Expired => {
                            log::warn!("Batch request {} expired", request_hash);
                            let error_json = serde_json::json!({
                                "error": {
                                    "type": "expired",
                                    "message": "Batch request expired"
                                }
                            })
                            .to_string();
                            updates.push((error_json, request_hash));
                        }
                    }
                }

                // Write all results for this batch atomically.
                let connection = self.connection.lock().unwrap();
                connection.with_savepoint("batch_download", || {
                    let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
                    let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
                    for (response_json, request_hash) in &updates {
                        exec((response_json.as_str(), request_hash.as_str()))?;
                    }
                    Ok(())
                })?;
                log::info!("Downloaded {} successful requests", success_count);
            }
        }

        Ok(())
    }
511
512 async fn upload_pending_requests(&self) -> Result<Vec<String>> {
513 const BATCH_CHUNK_SIZE: i32 = 16_000;
514 let mut all_batch_ids = Vec::new();
515 let mut total_uploaded = 0;
516
517 loop {
518 let rows: Vec<(String, String)> = {
519 let connection = self.connection.lock().unwrap();
520 let q = sql!(
521 SELECT request_hash, request FROM cache
522 WHERE batch_id IS NULL AND response IS NULL
523 LIMIT ?
524 );
525 connection.select_bound(q)?(BATCH_CHUNK_SIZE)?
526 };
527
528 if rows.is_empty() {
529 break;
530 }
531
532 let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
533
534 let batch_requests = rows
535 .iter()
536 .map(|(hash, request_str)| {
537 let serializable_request: SerializableRequest =
538 serde_json::from_str(&request_str).unwrap();
539
540 let messages: Vec<Message> = serializable_request
541 .messages
542 .into_iter()
543 .map(|msg| Message {
544 role: match msg.role.as_str() {
545 "user" => Role::User,
546 "assistant" => Role::Assistant,
547 _ => Role::User,
548 },
549 content: vec![RequestContent::Text {
550 text: msg.content,
551 cache_control: None,
552 }],
553 })
554 .collect();
555
556 let params = AnthropicRequest {
557 model: serializable_request.model,
558 max_tokens: serializable_request.max_tokens,
559 messages,
560 tools: Vec::new(),
561 thinking: None,
562 tool_choice: None,
563 system: None,
564 metadata: None,
565 output_config: None,
566 stop_sequences: Vec::new(),
567 temperature: None,
568 top_k: None,
569 top_p: None,
570 };
571
572 let custom_id = format!("req_hash_{}", hash);
573 anthropic::batches::BatchRequest { custom_id, params }
574 })
575 .collect::<Vec<_>>();
576
577 let batch_len = batch_requests.len();
578 let batch = anthropic::batches::create_batch(
579 self.http_client.as_ref(),
580 ANTHROPIC_API_URL,
581 &self.api_key,
582 anthropic::batches::CreateBatchRequest {
583 requests: batch_requests,
584 },
585 )
586 .await
587 .map_err(|e| anyhow::anyhow!("{:?}", e))?;
588
589 {
590 let connection = self.connection.lock().unwrap();
591 connection.with_savepoint("batch_upload", || {
592 let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
593 let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
594 for hash in &request_hashes {
595 exec((batch.id.as_str(), hash.as_str()))?;
596 }
597 Ok(())
598 })?;
599 }
600
601 total_uploaded += batch_len;
602 log::info!(
603 "Uploaded batch {} with {} requests ({} total)",
604 batch.id,
605 batch_len,
606 total_uploaded
607 );
608
609 all_batch_ids.push(batch.id);
610 }
611
612 if !all_batch_ids.is_empty() {
613 log::info!(
614 "Finished uploading {} batches with {} total requests",
615 all_batch_ids.len(),
616 total_uploaded
617 );
618 }
619
620 Ok(all_batch_ids)
621 }
622
    /// Computes the cache key for a request: a 16-hex-digit string over
    /// model, max_tokens, each message's text content, and the optional seed.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable
    /// across Rust releases, so keys in a persisted cache may stop matching
    /// after a toolchain upgrade — confirm this is acceptable.
    /// NOTE(review): message roles are excluded from the hash, so
    /// conversations differing only in role assignment collide.
    fn request_hash(
        model: &str,
        max_tokens: u64,
        messages: &[Message],
        seed: Option<usize>,
    ) -> String {
        let mut hasher = std::hash::DefaultHasher::new();
        model.hash(&mut hasher);
        max_tokens.hash(&mut hasher);
        for msg in messages {
            // Only the concatenated text content participates in the hash.
            message_content_to_string(&msg.content).hash(&mut hasher);
        }
        if let Some(seed) = seed {
            seed.hash(&mut hasher);
        }
        let request_hash = hasher.finish();
        // Zero-padded lowercase hex: always exactly 16 characters.
        format!("{request_hash:016x}")
    }
641}
642
643fn message_content_to_string(content: &[RequestContent]) -> String {
644 content
645 .iter()
646 .filter_map(|c| match c {
647 RequestContent::Text { text, .. } => Some(text.clone()),
648 _ => None,
649 })
650 .collect::<Vec<String>>()
651 .join("\n")
652}
653
/// Front-end over the two client implementations plus a test placeholder.
pub enum AnthropicClient {
    // No batching
    Plain(PlainLlmClient),
    // SQLite-cache-backed batching client.
    Batch(BatchingLlmClient),
    // Placeholder; panics if any API method is invoked on it.
    Dummy,
}
660
661impl AnthropicClient {
662 pub fn plain() -> Result<Self> {
663 Ok(Self::Plain(PlainLlmClient::new()?))
664 }
665
666 pub fn batch(cache_path: &Path) -> Result<Self> {
667 Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
668 }
669
670 #[allow(dead_code)]
671 pub fn dummy() -> Self {
672 Self::Dummy
673 }
674
675 pub async fn generate(
676 &self,
677 model: &str,
678 max_tokens: u64,
679 messages: Vec<Message>,
680 seed: Option<usize>,
681 cache_only: bool,
682 ) -> Result<Option<AnthropicResponse>> {
683 match self {
684 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
685 .generate(model, max_tokens, messages)
686 .await
687 .map(Some),
688 AnthropicClient::Batch(batching_llm_client) => {
689 batching_llm_client
690 .generate(model, max_tokens, messages, seed, cache_only)
691 .await
692 }
693 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
694 }
695 }
696
697 #[allow(dead_code)]
698 pub async fn generate_streaming<F>(
699 &self,
700 model: &str,
701 max_tokens: u64,
702 messages: Vec<Message>,
703 on_progress: F,
704 ) -> Result<Option<AnthropicResponse>>
705 where
706 F: FnMut(usize, &str),
707 {
708 match self {
709 AnthropicClient::Plain(plain_llm_client) => plain_llm_client
710 .generate_streaming(model, max_tokens, messages, on_progress)
711 .await
712 .map(Some),
713 AnthropicClient::Batch(_) => {
714 anyhow::bail!("Streaming not supported with batching client")
715 }
716 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
717 }
718 }
719
720 pub async fn sync_batches(&self) -> Result<()> {
721 match self {
722 AnthropicClient::Plain(_) => Ok(()),
723 AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
724 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
725 }
726 }
727
728 pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
729 match self {
730 AnthropicClient::Plain(_) => {
731 anyhow::bail!("Import batches is only supported with batching client")
732 }
733 AnthropicClient::Batch(batching_llm_client) => {
734 batching_llm_client.import_batches(batch_ids).await
735 }
736 AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
737 }
738 }
739}