// llm_client.rs

  1use anthropic::{
  2    ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
  3    Response as AnthropicResponse, Role, non_streaming_completion,
  4};
  5use anyhow::Result;
  6use http_client::HttpClient;
  7use indoc::indoc;
  8use sqlez::bindable::Bind;
  9use sqlez::bindable::StaticColumnCount;
 10use sqlez_macros::sql;
 11use std::hash::Hash;
 12use std::hash::Hasher;
 13use std::sync::Arc;
 14
/// Direct (non-batching) Anthropic client: every `generate` call performs an
/// immediate API request.
pub struct PlainLlmClient {
    // Shared HTTP transport used for all API calls.
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key, read from the ANTHROPIC_API_KEY environment variable.
    api_key: String,
}
 19
 20impl PlainLlmClient {
 21    fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
 22        let api_key = std::env::var("ANTHROPIC_API_KEY")
 23            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
 24        Ok(Self {
 25            http_client,
 26            api_key,
 27        })
 28    }
 29
 30    async fn generate(
 31        &self,
 32        model: String,
 33        max_tokens: u64,
 34        messages: Vec<Message>,
 35    ) -> Result<AnthropicResponse> {
 36        let request = AnthropicRequest {
 37            model,
 38            max_tokens,
 39            messages,
 40            tools: Vec::new(),
 41            thinking: None,
 42            tool_choice: None,
 43            system: None,
 44            metadata: None,
 45            stop_sequences: Vec::new(),
 46            temperature: None,
 47            top_k: None,
 48            top_p: None,
 49        };
 50
 51        let response = non_streaming_completion(
 52            self.http_client.as_ref(),
 53            ANTHROPIC_API_URL,
 54            &self.api_key,
 55            request,
 56            None,
 57        )
 58        .await
 59        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
 60
 61        Ok(response)
 62    }
 63}
 64
/// LLM client that defers requests to the Anthropic batch API, caching both
/// pending requests and finished responses in a local SQLite database.
pub struct BatchingLlmClient {
    // SQLite cache: one row per request (schema created in `new`).
    connection: sqlez::connection::Connection,
    // Shared HTTP transport used for all API calls.
    http_client: Arc<dyn HttpClient>,
    // Anthropic API key, read from the ANTHROPIC_API_KEY environment variable.
    api_key: String,
}
 70
/// One row of the `cache` table; mirrors the schema created in
/// `BatchingLlmClient::new`.
struct CacheRow {
    // Primary key: hex hash of (model, max_tokens, message text).
    request_hash: String,
    // Serialized request JSON, set when the request is queued for batching.
    request: Option<String>,
    // Serialized response JSON, set once the batch result has been downloaded.
    response: Option<String>,
    // Id of the uploaded batch this request belongs to, if any.
    batch_id: Option<String>,
}
 77
impl StaticColumnCount for CacheRow {
    // CacheRow always binds four columns: request_hash, request, response, batch_id.
    fn column_count() -> usize {
        4
    }
}
 83
 84impl Bind for CacheRow {
 85    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
 86        let next_index = statement.bind(&self.request_hash, start_index)?;
 87        let next_index = statement.bind(&self.request, next_index)?;
 88        let next_index = statement.bind(&self.response, next_index)?;
 89        let next_index = statement.bind(&self.batch_id, next_index)?;
 90        Ok(next_index)
 91    }
 92}
 93
/// JSON-serializable form of a request, stored in the cache's `request`
/// column so it can be replayed when a batch is uploaded.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<SerializableMessage>,
}
100
/// JSON-serializable message: role as a string ("user" / "assistant") and the
/// message's text content flattened into a single string.
#[derive(serde::Serialize, serde::Deserialize)]
struct SerializableMessage {
    role: String,
    content: String,
}
106
107impl BatchingLlmClient {
108    fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
109        let api_key = std::env::var("ANTHROPIC_API_KEY")
110            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
111
112        let connection = sqlez::connection::Connection::open_file(&cache_path);
113        let mut statement = sqlez::statement::Statement::prepare(
114            &connection,
115            indoc! {"
116                CREATE TABLE IF NOT EXISTS cache (
117                    request_hash TEXT PRIMARY KEY,
118                    request TEXT,
119                    response TEXT,
120                    batch_id TEXT
121                );
122                "},
123        )?;
124        statement.exec()?;
125        drop(statement);
126
127        Ok(Self {
128            connection,
129            http_client,
130            api_key,
131        })
132    }
133
134    pub fn lookup(
135        &self,
136        model: &str,
137        max_tokens: u64,
138        messages: &[Message],
139    ) -> Result<Option<AnthropicResponse>> {
140        let request_hash_str = Self::request_hash(model, max_tokens, messages);
141        let response: Vec<String> = self.connection.select_bound(
142            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
143        )?(request_hash_str.as_str())?;
144        Ok(response
145            .into_iter()
146            .next()
147            .and_then(|text| serde_json::from_str(&text).ok()))
148    }
149
150    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
151        let request_hash = Self::request_hash(model, max_tokens, messages);
152
153        let serializable_messages: Vec<SerializableMessage> = messages
154            .iter()
155            .map(|msg| SerializableMessage {
156                role: match msg.role {
157                    Role::User => "user".to_string(),
158                    Role::Assistant => "assistant".to_string(),
159                },
160                content: message_content_to_string(&msg.content),
161            })
162            .collect();
163
164        let serializable_request = SerializableRequest {
165            model: model.to_string(),
166            max_tokens,
167            messages: serializable_messages,
168        };
169
170        let request = Some(serde_json::to_string(&serializable_request)?);
171        let cache_row = CacheRow {
172            request_hash,
173            request,
174            response: None,
175            batch_id: None,
176        };
177        self.connection.exec_bound(sql!(
178            INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
179            cache_row,
180        )
181    }
182
183    async fn generate(
184        &self,
185        model: String,
186        max_tokens: u64,
187        messages: Vec<Message>,
188    ) -> Result<Option<AnthropicResponse>> {
189        let response = self.lookup(&model, max_tokens, &messages)?;
190        if let Some(response) = response {
191            return Ok(Some(response));
192        }
193
194        self.mark_for_batch(&model, max_tokens, &messages)?;
195
196        Ok(None)
197    }
198
199    /// Uploads pending requests as a new batch; downloads finished batches if any.
200    async fn sync_batches(&self) -> Result<()> {
201        self.upload_pending_requests().await?;
202        self.download_finished_batches().await
203    }
204
205    async fn download_finished_batches(&self) -> Result<()> {
206        let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
207        let batch_ids: Vec<String> = self.connection.select(q)?()?;
208
209        for batch_id in batch_ids {
210            let batch_status = anthropic::batches::retrieve_batch(
211                self.http_client.as_ref(),
212                ANTHROPIC_API_URL,
213                &self.api_key,
214                &batch_id,
215            )
216            .await
217            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
218
219            log::info!(
220                "Batch {} status: {}",
221                batch_id,
222                batch_status.processing_status
223            );
224
225            if batch_status.processing_status == "ended" {
226                let results = anthropic::batches::retrieve_batch_results(
227                    self.http_client.as_ref(),
228                    ANTHROPIC_API_URL,
229                    &self.api_key,
230                    &batch_id,
231                )
232                .await
233                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
234
235                let mut success_count = 0;
236                for result in results {
237                    let request_hash = result
238                        .custom_id
239                        .strip_prefix("req_hash_")
240                        .unwrap_or(&result.custom_id)
241                        .to_string();
242
243                    match result.result {
244                        anthropic::batches::BatchResult::Succeeded { message } => {
245                            let response_json = serde_json::to_string(&message)?;
246                            let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
247                            self.connection.exec_bound(q)?((response_json, request_hash))?;
248                            success_count += 1;
249                        }
250                        anthropic::batches::BatchResult::Errored { error } => {
251                            log::error!("Batch request {} failed: {:?}", request_hash, error);
252                        }
253                        anthropic::batches::BatchResult::Canceled => {
254                            log::warn!("Batch request {} was canceled", request_hash);
255                        }
256                        anthropic::batches::BatchResult::Expired => {
257                            log::warn!("Batch request {} expired", request_hash);
258                        }
259                    }
260                }
261                log::info!("Uploaded {} successful requests", success_count);
262            }
263        }
264
265        Ok(())
266    }
267
268    async fn upload_pending_requests(&self) -> Result<String> {
269        let q = sql!(
270        SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
271        );
272
273        let rows: Vec<(String, String)> = self.connection.select(q)?()?;
274
275        if rows.is_empty() {
276            return Ok(String::new());
277        }
278
279        let batch_requests = rows
280            .iter()
281            .map(|(hash, request_str)| {
282                let serializable_request: SerializableRequest =
283                    serde_json::from_str(&request_str).unwrap();
284
285                let messages: Vec<Message> = serializable_request
286                    .messages
287                    .into_iter()
288                    .map(|msg| Message {
289                        role: match msg.role.as_str() {
290                            "user" => Role::User,
291                            "assistant" => Role::Assistant,
292                            _ => Role::User,
293                        },
294                        content: vec![RequestContent::Text {
295                            text: msg.content,
296                            cache_control: None,
297                        }],
298                    })
299                    .collect();
300
301                let params = AnthropicRequest {
302                    model: serializable_request.model,
303                    max_tokens: serializable_request.max_tokens,
304                    messages,
305                    tools: Vec::new(),
306                    thinking: None,
307                    tool_choice: None,
308                    system: None,
309                    metadata: None,
310                    stop_sequences: Vec::new(),
311                    temperature: None,
312                    top_k: None,
313                    top_p: None,
314                };
315
316                let custom_id = format!("req_hash_{}", hash);
317                anthropic::batches::BatchRequest { custom_id, params }
318            })
319            .collect::<Vec<_>>();
320
321        let batch_len = batch_requests.len();
322        let batch = anthropic::batches::create_batch(
323            self.http_client.as_ref(),
324            ANTHROPIC_API_URL,
325            &self.api_key,
326            anthropic::batches::CreateBatchRequest {
327                requests: batch_requests,
328            },
329        )
330        .await
331        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
332
333        let q = sql!(
334            UPDATE cache SET batch_id = ? WHERE batch_id is NULL
335        );
336        self.connection.exec_bound(q)?(batch.id.as_str())?;
337
338        log::info!("Uploaded batch with {} requests", batch_len);
339
340        Ok(batch.id)
341    }
342
343    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
344        let mut hasher = std::hash::DefaultHasher::new();
345        model.hash(&mut hasher);
346        max_tokens.hash(&mut hasher);
347        for msg in messages {
348            message_content_to_string(&msg.content).hash(&mut hasher);
349        }
350        let request_hash = hasher.finish();
351        format!("{request_hash:016x}")
352    }
353}
354
355fn message_content_to_string(content: &[RequestContent]) -> String {
356    content
357        .iter()
358        .filter_map(|c| match c {
359            RequestContent::Text { text, .. } => Some(text.clone()),
360            _ => None,
361        })
362        .collect::<Vec<String>>()
363        .join("\n")
364}
365
/// Front-end over the available LLM client implementations.
pub enum LlmClient {
    // Sends each request immediately; no batching.
    Plain(PlainLlmClient),
    // Queues requests into the Anthropic batch API via a local SQLite cache.
    Batch(BatchingLlmClient),
    // Placeholder that panics if used; for code paths that never call the LLM.
    Dummy,
}
372
373impl LlmClient {
374    pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
375        Ok(Self::Plain(PlainLlmClient::new(http_client)?))
376    }
377
378    pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
379        Ok(Self::Batch(BatchingLlmClient::new(
380            cache_path,
381            http_client,
382        )?))
383    }
384
385    #[allow(dead_code)]
386    pub fn dummy() -> Self {
387        Self::Dummy
388    }
389
390    pub async fn generate(
391        &self,
392        model: String,
393        max_tokens: u64,
394        messages: Vec<Message>,
395    ) -> Result<Option<AnthropicResponse>> {
396        match self {
397            LlmClient::Plain(plain_llm_client) => plain_llm_client
398                .generate(model, max_tokens, messages)
399                .await
400                .map(Some),
401            LlmClient::Batch(batching_llm_client) => {
402                batching_llm_client
403                    .generate(model, max_tokens, messages)
404                    .await
405            }
406            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
407        }
408    }
409
410    pub async fn sync_batches(&self) -> Result<()> {
411        match self {
412            LlmClient::Plain(_) => Ok(()),
413            LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
414            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
415        }
416    }
417}