collab: Track cache writes/reads in LLM usage (#18834)

Created by Marshall Bowers, Antonio Scandurra, and Antonio

This PR extends LLM usage tracking to also record cache writes and cache
reads for Anthropic models.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>

Change summary

crates/anthropic/src/anthropic.rs                                      |   4 
crates/collab/migrations_llm/20241007173634_add_cache_token_counts.sql |  11 
crates/collab/src/llm.rs                                               |  80 
crates/collab/src/llm/db/queries/usages.rs                             | 113 
crates/collab/src/llm/db/tables/lifetime_usage.rs                      |   2 
crates/collab/src/llm/db/tables/model.rs                               |   2 
crates/collab/src/llm/db/tables/usage_measure.rs                       |   2 
crates/collab/src/llm/db/tests/usage_tests.rs                          |  62 
crates/collab/src/llm/telemetry.rs                                     |   4 
9 files changed, 241 insertions(+), 39 deletions(-)

Detailed changes

crates/anthropic/src/anthropic.rs 🔗

@@ -521,6 +521,10 @@ pub struct Usage {
     pub input_tokens: Option<u32>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub output_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub cache_creation_input_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub cache_read_input_tokens: Option<u32>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
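
For context, Anthropic reports cache activity alongside the regular token counts in the `usage` object of its streaming events, and the two new optional fields pick that up. A minimal standalone sketch of the deserialization behavior (the struct is copied here for illustration; the numbers in the payload are made up):

```rust
use serde::{Deserialize, Serialize};

// Illustrative copy of the extended struct; the real one lives in
// crates/anthropic/src/anthropic.rs.
#[derive(Debug, Serialize, Deserialize)]
struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    output_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    cache_creation_input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    cache_read_input_tokens: Option<u32>,
}

fn main() -> serde_json::Result<()> {
    // Payload shape with cache fields present.
    let json = r#"{
        "input_tokens": 12,
        "output_tokens": 0,
        "cache_creation_input_tokens": 2048,
        "cache_read_input_tokens": 0
    }"#;
    let usage: Usage = serde_json::from_str(json)?;
    assert_eq!(usage.cache_creation_input_tokens, Some(2048));

    // Responses without the cache fields still deserialize, thanks to
    // #[serde(default)].
    let legacy: Usage = serde_json::from_str(r#"{"input_tokens": 12}"#)?;
    assert_eq!(legacy.cache_read_input_tokens, None);
    Ok(())
}
```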

crates/collab/migrations_llm/20241007173634_add_cache_token_counts.sql 🔗

@@ -0,0 +1,11 @@
+alter table models
+    add column price_per_million_cache_creation_input_tokens integer not null default 0,
+    add column price_per_million_cache_read_input_tokens integer not null default 0;
+
+alter table usages
+    add column cache_creation_input_tokens_this_month bigint not null default 0,
+    add column cache_read_input_tokens_this_month bigint not null default 0;
+
+alter table lifetime_usages
+    add column cache_creation_input_tokens bigint not null default 0,
+    add column cache_read_input_tokens bigint not null default 0;

crates/collab/src/llm.rs 🔗

@@ -318,22 +318,31 @@ async fn perform_completion(
             chunks
                 .map(move |event| {
                     let chunk = event?;
-                    let (input_tokens, output_tokens) = match &chunk {
+                    let (
+                        input_tokens,
+                        output_tokens,
+                        cache_creation_input_tokens,
+                        cache_read_input_tokens,
+                    ) = match &chunk {
                         anthropic::Event::MessageStart {
                             message: anthropic::Response { usage, .. },
                         }
                         | anthropic::Event::MessageDelta { usage, .. } => (
                             usage.input_tokens.unwrap_or(0) as usize,
                             usage.output_tokens.unwrap_or(0) as usize,
+                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
+                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
                         ),
-                        _ => (0, 0),
+                        _ => (0, 0, 0, 0),
                     };
 
-                    anyhow::Ok((
-                        serde_json::to_vec(&chunk).unwrap(),
+                    anyhow::Ok(CompletionChunk {
+                        bytes: serde_json::to_vec(&chunk).unwrap(),
                         input_tokens,
                         output_tokens,
-                    ))
+                        cache_creation_input_tokens,
+                        cache_read_input_tokens,
+                    })
                 })
                 .boxed()
         }
@@ -359,11 +368,13 @@ async fn perform_completion(
                             chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                         let output_tokens =
                             chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
-                        (
-                            serde_json::to_vec(&chunk).unwrap(),
+                        CompletionChunk {
+                            bytes: serde_json::to_vec(&chunk).unwrap(),
                             input_tokens,
                             output_tokens,
-                        )
+                            cache_creation_input_tokens: 0,
+                            cache_read_input_tokens: 0,
+                        }
                     })
                 })
                 .boxed()
@@ -387,13 +398,13 @@ async fn perform_completion(
                 .map(|event| {
                     event.map(|chunk| {
                         // TODO - implement token counting for Google AI
-                        let input_tokens = 0;
-                        let output_tokens = 0;
-                        (
-                            serde_json::to_vec(&chunk).unwrap(),
-                            input_tokens,
-                            output_tokens,
-                        )
+                        CompletionChunk {
+                            bytes: serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens: 0,
+                            output_tokens: 0,
+                            cache_creation_input_tokens: 0,
+                            cache_read_input_tokens: 0,
+                        }
                     })
                 })
                 .boxed()
@@ -407,6 +418,8 @@ async fn perform_completion(
         model,
         input_tokens: 0,
         output_tokens: 0,
+        cache_creation_input_tokens: 0,
+        cache_read_input_tokens: 0,
         inner_stream: stream,
     })))
 }
@@ -551,6 +564,14 @@ async fn check_usage_limit(
     Ok(())
 }
 
+struct CompletionChunk {
+    bytes: Vec<u8>,
+    input_tokens: usize,
+    output_tokens: usize,
+    cache_creation_input_tokens: usize,
+    cache_read_input_tokens: usize,
+}
+
 struct TokenCountingStream<S> {
     state: Arc<LlmState>,
     claims: LlmTokenClaims,
@@ -558,22 +579,26 @@ struct TokenCountingStream<S> {
     model: String,
     input_tokens: usize,
     output_tokens: usize,
+    cache_creation_input_tokens: usize,
+    cache_read_input_tokens: usize,
     inner_stream: S,
 }
 
 impl<S> Stream for TokenCountingStream<S>
 where
-    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
+    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
 {
     type Item = Result<Vec<u8>, anyhow::Error>;
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         match Pin::new(&mut self.inner_stream).poll_next(cx) {
-            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
-                bytes.push(b'\n');
-                self.input_tokens += input_tokens;
-                self.output_tokens += output_tokens;
-                Poll::Ready(Some(Ok(bytes)))
+            Poll::Ready(Some(Ok(mut chunk))) => {
+                chunk.bytes.push(b'\n');
+                self.input_tokens += chunk.input_tokens;
+                self.output_tokens += chunk.output_tokens;
+                self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
+                self.cache_read_input_tokens += chunk.cache_read_input_tokens;
+                Poll::Ready(Some(Ok(chunk.bytes)))
             }
             Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
             Poll::Ready(None) => Poll::Ready(None),
@@ -590,6 +615,8 @@ impl<S> Drop for TokenCountingStream<S> {
         let model = std::mem::take(&mut self.model);
         let input_token_count = self.input_tokens;
         let output_token_count = self.output_tokens;
+        let cache_creation_input_token_count = self.cache_creation_input_tokens;
+        let cache_read_input_token_count = self.cache_read_input_tokens;
         self.state.executor.spawn_detached(async move {
             let usage = state
                 .db
@@ -599,6 +626,8 @@ impl<S> Drop for TokenCountingStream<S> {
                     provider,
                     &model,
                     input_token_count,
+                    cache_creation_input_token_count,
+                    cache_read_input_token_count,
                     output_token_count,
                     Utc::now(),
                 )
@@ -630,11 +659,20 @@ impl<S> Drop for TokenCountingStream<S> {
                             model,
                             provider: provider.to_string(),
                             input_token_count: input_token_count as u64,
+                            cache_creation_input_token_count: cache_creation_input_token_count
+                                as u64,
+                            cache_read_input_token_count: cache_read_input_token_count as u64,
                             output_token_count: output_token_count as u64,
                             requests_this_minute: usage.requests_this_minute as u64,
                             tokens_this_minute: usage.tokens_this_minute as u64,
                             tokens_this_day: usage.tokens_this_day as u64,
                             input_tokens_this_month: usage.input_tokens_this_month as u64,
+                            cache_creation_input_tokens_this_month: usage
+                                .cache_creation_input_tokens_this_month
+                                as u64,
+                            cache_read_input_tokens_this_month: usage
+                                .cache_read_input_tokens_this_month
+                                as u64,
                             output_tokens_this_month: usage.output_tokens_this_month as u64,
                             spending_this_month: usage.spending_this_month as u64,
                             lifetime_spending: usage.lifetime_spending as u64,
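
A compact way to read the llm.rs change: each streamed chunk now carries its own token counts (including the two cache variants), the wrapper sums them as the response is forwarded, and the totals are reported once the stream is dropped. Below is a self-contained sketch of that accumulation pattern using only the `futures` crate; it illustrates the shape of the data flow, not the PR's exact types or event handling:

```rust
use futures::{executor::block_on, stream, StreamExt};

// Mirror of the new per-chunk shape introduced in llm.rs.
struct CompletionChunk {
    bytes: Vec<u8>,
    input_tokens: usize,
    output_tokens: usize,
    cache_creation_input_tokens: usize,
    cache_read_input_tokens: usize,
}

#[derive(Debug, Default)]
struct Totals {
    input_tokens: usize,
    output_tokens: usize,
    cache_creation_input_tokens: usize,
    cache_read_input_tokens: usize,
}

fn main() {
    // Two fake chunks standing in for Anthropic's MessageStart / MessageDelta events.
    let chunks = stream::iter([
        CompletionChunk {
            bytes: br#"{"type":"message_start"}"#.to_vec(),
            input_tokens: 1_000,
            output_tokens: 0,
            cache_creation_input_tokens: 500,
            cache_read_input_tokens: 0,
        },
        CompletionChunk {
            bytes: br#"{"type":"message_delta"}"#.to_vec(),
            input_tokens: 0,
            output_tokens: 200,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 300,
        },
    ]);

    // In the real TokenCountingStream this accumulation happens in poll_next,
    // and the totals are handed to record_usage from Drop.
    let totals = block_on(chunks.fold(Totals::default(), |mut totals, chunk| async move {
        totals.input_tokens += chunk.input_tokens;
        totals.output_tokens += chunk.output_tokens;
        totals.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
        totals.cache_read_input_tokens += chunk.cache_read_input_tokens;
        let _forwarded_to_client = chunk.bytes; // bytes are streamed back to the client here
        totals
    }));

    assert_eq!(totals.cache_creation_input_tokens, 500);
    assert_eq!(totals.cache_read_input_tokens, 300);
    println!("{totals:?}");
}
```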

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -14,6 +14,8 @@ pub struct Usage {
     pub tokens_this_minute: usize,
     pub tokens_this_day: usize,
     pub input_tokens_this_month: usize,
+    pub cache_creation_input_tokens_this_month: usize,
+    pub cache_read_input_tokens_this_month: usize,
     pub output_tokens_this_month: usize,
     pub spending_this_month: usize,
     pub lifetime_spending: usize,
@@ -160,17 +162,14 @@ impl LlmDatabase {
                 .all(&*tx)
                 .await?;
 
-            let (lifetime_input_tokens, lifetime_output_tokens) = lifetime_usage::Entity::find()
+            let lifetime_usage = lifetime_usage::Entity::find()
                 .filter(
                     lifetime_usage::Column::UserId
                         .eq(user_id)
                         .and(lifetime_usage::Column::ModelId.eq(model.id)),
                 )
                 .one(&*tx)
-                .await?
-                .map_or((0, 0), |usage| {
-                    (usage.input_tokens as usize, usage.output_tokens as usize)
-                });
+                .await?;
 
             let requests_this_minute =
                 self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
@@ -180,18 +179,44 @@ impl LlmDatabase {
                 self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
             let input_tokens_this_month =
                 self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMonth)?;
+            let cache_creation_input_tokens_this_month = self.get_usage_for_measure(
+                &usages,
+                now,
+                UsageMeasure::CacheCreationInputTokensPerMonth,
+            )?;
+            let cache_read_input_tokens_this_month = self.get_usage_for_measure(
+                &usages,
+                now,
+                UsageMeasure::CacheReadInputTokensPerMonth,
+            )?;
             let output_tokens_this_month =
                 self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?;
-            let spending_this_month =
-                calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
-            let lifetime_spending =
-                calculate_spending(model, lifetime_input_tokens, lifetime_output_tokens);
+            let spending_this_month = calculate_spending(
+                model,
+                input_tokens_this_month,
+                cache_creation_input_tokens_this_month,
+                cache_read_input_tokens_this_month,
+                output_tokens_this_month,
+            );
+            let lifetime_spending = if let Some(lifetime_usage) = lifetime_usage {
+                calculate_spending(
+                    model,
+                    lifetime_usage.input_tokens as usize,
+                    lifetime_usage.cache_creation_input_tokens as usize,
+                    lifetime_usage.cache_read_input_tokens as usize,
+                    lifetime_usage.output_tokens as usize,
+                )
+            } else {
+                0
+            };
 
             Ok(Usage {
                 requests_this_minute,
                 tokens_this_minute,
                 tokens_this_day,
                 input_tokens_this_month,
+                cache_creation_input_tokens_this_month,
+                cache_read_input_tokens_this_month,
                 output_tokens_this_month,
                 spending_this_month,
                 lifetime_spending,
@@ -208,6 +233,8 @@ impl LlmDatabase {
         provider: LanguageModelProvider,
         model_name: &str,
         input_token_count: usize,
+        cache_creation_input_tokens: usize,
+        cache_read_input_tokens: usize,
         output_token_count: usize,
         now: DateTimeUtc,
     ) -> Result<Usage> {
@@ -235,6 +262,10 @@ impl LlmDatabase {
                     &tx,
                 )
                 .await?;
+            let total_token_count = input_token_count
+                + cache_read_input_tokens
+                + cache_creation_input_tokens
+                + output_token_count;
             let tokens_this_minute = self
                 .update_usage_for_measure(
                     user_id,
@@ -243,7 +274,7 @@ impl LlmDatabase {
                     &usages,
                     UsageMeasure::TokensPerMinute,
                     now,
-                    input_token_count + output_token_count,
+                    total_token_count,
                     &tx,
                 )
                 .await?;
@@ -255,7 +286,7 @@ impl LlmDatabase {
                     &usages,
                     UsageMeasure::TokensPerDay,
                     now,
-                    input_token_count + output_token_count,
+                    total_token_count,
                     &tx,
                 )
                 .await?;
@@ -271,6 +302,30 @@ impl LlmDatabase {
                     &tx,
                 )
                 .await?;
+            let cache_creation_input_tokens_this_month = self
+                .update_usage_for_measure(
+                    user_id,
+                    is_staff,
+                    model.id,
+                    &usages,
+                    UsageMeasure::CacheCreationInputTokensPerMonth,
+                    now,
+                    cache_creation_input_tokens,
+                    &tx,
+                )
+                .await?;
+            let cache_read_input_tokens_this_month = self
+                .update_usage_for_measure(
+                    user_id,
+                    is_staff,
+                    model.id,
+                    &usages,
+                    UsageMeasure::CacheReadInputTokensPerMonth,
+                    now,
+                    cache_read_input_tokens,
+                    &tx,
+                )
+                .await?;
             let output_tokens_this_month = self
                 .update_usage_for_measure(
                     user_id,
@@ -283,8 +338,13 @@ impl LlmDatabase {
                     &tx,
                 )
                 .await?;
-            let spending_this_month =
-                calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
+            let spending_this_month = calculate_spending(
+                model,
+                input_tokens_this_month,
+                cache_creation_input_tokens_this_month,
+                cache_read_input_tokens_this_month,
+                output_tokens_this_month,
+            );
 
             // Update lifetime usage
             let lifetime_usage = lifetime_usage::Entity::find()
@@ -303,6 +363,12 @@ impl LlmDatabase {
                         input_tokens: ActiveValue::set(
                             usage.input_tokens + input_token_count as i64,
                         ),
+                        cache_creation_input_tokens: ActiveValue::set(
+                            usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
+                        ),
+                        cache_read_input_tokens: ActiveValue::set(
+                            usage.cache_read_input_tokens + cache_read_input_tokens as i64,
+                        ),
                         output_tokens: ActiveValue::set(
                             usage.output_tokens + output_token_count as i64,
                         ),
@@ -327,6 +393,8 @@ impl LlmDatabase {
             let lifetime_spending = calculate_spending(
                 model,
                 lifetime_usage.input_tokens as usize,
+                lifetime_usage.cache_creation_input_tokens as usize,
+                lifetime_usage.cache_read_input_tokens as usize,
                 lifetime_usage.output_tokens as usize,
             );
 
@@ -335,6 +403,8 @@ impl LlmDatabase {
                 tokens_this_minute,
                 tokens_this_day,
                 input_tokens_this_month,
+                cache_creation_input_tokens_this_month,
+                cache_read_input_tokens_this_month,
                 output_tokens_this_month,
                 spending_this_month,
                 lifetime_spending,
@@ -501,13 +571,24 @@ impl LlmDatabase {
 fn calculate_spending(
     model: &model::Model,
     input_tokens_this_month: usize,
+    cache_creation_input_tokens_this_month: usize,
+    cache_read_input_tokens_this_month: usize,
     output_tokens_this_month: usize,
 ) -> usize {
     let input_token_cost =
         input_tokens_this_month * model.price_per_million_input_tokens as usize / 1_000_000;
+    let cache_creation_input_token_cost = cache_creation_input_tokens_this_month
+        * model.price_per_million_cache_creation_input_tokens as usize
+        / 1_000_000;
+    let cache_read_input_token_cost = cache_read_input_tokens_this_month
+        * model.price_per_million_cache_read_input_tokens as usize
+        / 1_000_000;
     let output_token_cost =
         output_tokens_this_month * model.price_per_million_output_tokens as usize / 1_000_000;
-    input_token_cost + output_token_cost
+    input_token_cost
+        + cache_creation_input_token_cost
+        + cache_read_input_token_cost
+        + output_token_cost
 }
 
 const MINUTE_BUCKET_COUNT: usize = 12;
@@ -521,6 +602,8 @@ impl UsageMeasure {
             UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
             UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
             UsageMeasure::InputTokensPerMonth => MONTH_BUCKET_COUNT,
+            UsageMeasure::CacheCreationInputTokensPerMonth => MONTH_BUCKET_COUNT,
+            UsageMeasure::CacheReadInputTokensPerMonth => MONTH_BUCKET_COUNT,
             UsageMeasure::OutputTokensPerMonth => MONTH_BUCKET_COUNT,
         }
     }
@@ -531,6 +614,8 @@ impl UsageMeasure {
             UsageMeasure::TokensPerMinute => Duration::minutes(1),
             UsageMeasure::TokensPerDay => Duration::hours(24),
             UsageMeasure::InputTokensPerMonth => Duration::days(30),
+            UsageMeasure::CacheCreationInputTokensPerMonth => Duration::days(30),
+            UsageMeasure::CacheReadInputTokensPerMonth => Duration::days(30),
             UsageMeasure::OutputTokensPerMonth => Duration::days(30),
         }
     }
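
To make the extended `calculate_spending` arithmetic concrete, here is a worked sketch with made-up per-million prices; the diff does not show the real values or their unit (the migration defaults them all to 0), so the constants below are placeholders only:

```rust
// Hypothetical per-million prices in whatever integer unit the models table uses.
const PRICE_INPUT: usize = 300;
const PRICE_CACHE_CREATION: usize = 375;
const PRICE_CACHE_READ: usize = 30;
const PRICE_OUTPUT: usize = 1_500;

// Same shape as the extended calculate_spending in usages.rs.
fn calculate_spending(
    input_tokens: usize,
    cache_creation_input_tokens: usize,
    cache_read_input_tokens: usize,
    output_tokens: usize,
) -> usize {
    input_tokens * PRICE_INPUT / 1_000_000
        + cache_creation_input_tokens * PRICE_CACHE_CREATION / 1_000_000
        + cache_read_input_tokens * PRICE_CACHE_READ / 1_000_000
        + output_tokens * PRICE_OUTPUT / 1_000_000
}

fn main() {
    // 2M plain input tokens, 1M cache writes, 4M cache reads, 500k output tokens.
    let spend = calculate_spending(2_000_000, 1_000_000, 4_000_000, 500_000);
    // 2 * 300 + 1 * 375 + 4 * 30 + 0.5 * 1_500 = 600 + 375 + 120 + 750 = 1_845
    assert_eq!(spend, 1_845);
    println!("spending: {spend}");
}
```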

crates/collab/src/llm/db/tables/model.rs 🔗

@@ -14,6 +14,8 @@ pub struct Model {
     pub max_tokens_per_minute: i64,
     pub max_tokens_per_day: i64,
     pub price_per_million_input_tokens: i32,
+    pub price_per_million_cache_creation_input_tokens: i32,
+    pub price_per_million_cache_read_input_tokens: i32,
     pub price_per_million_output_tokens: i32,
 }
 

crates/collab/src/llm/db/tests/usage_tests.rs 🔗

@@ -33,12 +33,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
     let user_id = UserId::from_proto(123);
 
     let now = t0;
-    db.record_usage(user_id, false, provider, model, 1000, 0, now)
+    db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now)
         .await
         .unwrap();
 
     let now = t0 + Duration::seconds(10);
-    db.record_usage(user_id, false, provider, model, 2000, 0, now)
+    db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now)
         .await
         .unwrap();
 
@@ -50,6 +50,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             tokens_this_minute: 3000,
             tokens_this_day: 3000,
             input_tokens_this_month: 3000,
+            cache_creation_input_tokens_this_month: 0,
+            cache_read_input_tokens_this_month: 0,
             output_tokens_this_month: 0,
             spending_this_month: 0,
             lifetime_spending: 0,
@@ -65,6 +67,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             tokens_this_minute: 2000,
             tokens_this_day: 3000,
             input_tokens_this_month: 3000,
+            cache_creation_input_tokens_this_month: 0,
+            cache_read_input_tokens_this_month: 0,
             output_tokens_this_month: 0,
             spending_this_month: 0,
             lifetime_spending: 0,
@@ -72,7 +76,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
     );
 
     let now = t0 + Duration::seconds(60);
-    db.record_usage(user_id, false, provider, model, 3000, 0, now)
+    db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now)
         .await
         .unwrap();
 
@@ -84,6 +88,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             tokens_this_minute: 5000,
             tokens_this_day: 6000,
             input_tokens_this_month: 6000,
+            cache_creation_input_tokens_this_month: 0,
+            cache_read_input_tokens_this_month: 0,
             output_tokens_this_month: 0,
             spending_this_month: 0,
             lifetime_spending: 0,
@@ -100,13 +106,15 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             tokens_this_minute: 0,
             tokens_this_day: 5000,
             input_tokens_this_month: 6000,
+            cache_creation_input_tokens_this_month: 0,
+            cache_read_input_tokens_this_month: 0,
             output_tokens_this_month: 0,
             spending_this_month: 0,
             lifetime_spending: 0,
         }
     );
 
-    db.record_usage(user_id, false, provider, model, 4000, 0, now)
+    db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now)
         .await
         .unwrap();
 
@@ -118,6 +126,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             tokens_this_minute: 4000,
             tokens_this_day: 9000,
             input_tokens_this_month: 10000,
+            cache_creation_input_tokens_this_month: 0,
+            cache_read_input_tokens_this_month: 0,
             output_tokens_this_month: 0,
             spending_this_month: 0,
             lifetime_spending: 0,
@@ -134,6 +144,50 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             tokens_this_minute: 0,
             tokens_this_day: 0,
             input_tokens_this_month: 9000,
+            cache_creation_input_tokens_this_month: 0,
+            cache_read_input_tokens_this_month: 0,
+            output_tokens_this_month: 0,
+            spending_this_month: 0,
+            lifetime_spending: 0,
+        }
+    );
+
+    // Test cache creation input tokens
+    db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
+        .await
+        .unwrap();
+
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 1,
+            tokens_this_minute: 1500,
+            tokens_this_day: 1500,
+            input_tokens_this_month: 10000,
+            cache_creation_input_tokens_this_month: 500,
+            cache_read_input_tokens_this_month: 0,
+            output_tokens_this_month: 0,
+            spending_this_month: 0,
+            lifetime_spending: 0,
+        }
+    );
+
+    // Test cache read input tokens
+    db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now)
+        .await
+        .unwrap();
+
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 2,
+            tokens_this_minute: 2800,
+            tokens_this_day: 2800,
+            input_tokens_this_month: 11000,
+            cache_creation_input_tokens_this_month: 500,
+            cache_read_input_tokens_this_month: 300,
             output_tokens_this_month: 0,
             spending_this_month: 0,
             lifetime_spending: 0,
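
The figures in the new assertions follow from the `total_token_count` change in usages.rs, where cache creation and cache read tokens now count toward the per-minute and per-day buckets. A quick check of the arithmetic the test expects:

```rust
// total = input + cache_creation + cache_read + output, as in usages.rs.
fn total_token_count(input: usize, cache_creation: usize, cache_read: usize, output: usize) -> usize {
    input + cache_creation + cache_read + output
}

fn main() {
    // First new record_usage call: 1000 input + 500 cache-creation tokens.
    let first = total_token_count(1_000, 500, 0, 0);
    assert_eq!(first, 1_500); // matches tokens_this_minute / tokens_this_day: 1500

    // Second call adds 1000 input + 300 cache-read tokens in the same minute.
    let second = total_token_count(1_000, 0, 300, 0);
    assert_eq!(first + second, 2_800); // matches the 2800 in the final assertion
}
```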

crates/collab/src/llm/telemetry.rs 🔗

@@ -12,11 +12,15 @@ pub struct LlmUsageEventRow {
     pub model: String,
     pub provider: String,
     pub input_token_count: u64,
+    pub cache_creation_input_token_count: u64,
+    pub cache_read_input_token_count: u64,
     pub output_token_count: u64,
     pub requests_this_minute: u64,
     pub tokens_this_minute: u64,
     pub tokens_this_day: u64,
     pub input_tokens_this_month: u64,
+    pub cache_creation_input_tokens_this_month: u64,
+    pub cache_read_input_tokens_this_month: u64,
     pub output_tokens_this_month: u64,
     pub spending_this_month: u64,
     pub lifetime_spending: u64,