collab: Track input and output tokens per minute separately (#28097)

Marshall Bowers created

This PR adds separate tracking of input and output tokens per minute, alongside the existing aggregate tokens-per-minute measure.

We are not yet rate-limiting based on these measures.
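
The two new measures extend the existing `UsageMeasure` enum (the `usage_measure.rs` change listed in the summary below, whose diff is not shown here). A minimal sketch of the assumed resulting shape, with derives and any serialization attributes omitted:

```rust
// Sketch only: variant names come from the diffs below; derives, attributes,
// and ordering in the real `usage_measure.rs` may differ.
pub enum UsageMeasure {
    RequestsPerMinute,
    TokensPerMinute,
    InputTokensPerMinute,  // new in this PR
    OutputTokensPerMinute, // new in this PR
    TokensPerDay,
}
```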

Release Notes:

- N/A

Change summary

crates/collab/src/llm.rs                         | 14 +++
crates/collab/src/llm/db/queries/usages.rs       | 73 +++++++++++++++++
crates/collab/src/llm/db/tables/usage_measure.rs |  2 ++
crates/collab/src/llm/db/tests/usage_tests.rs    | 14 +++
4 files changed, 101 insertions(+), 2 deletions(-)

Detailed changes

crates/collab/src/llm.rs

@@ -499,6 +499,10 @@ async fn check_usage_limit(
         model.max_requests_per_minute as usize / users_in_recent_minutes;
     let per_user_max_tokens_per_minute =
         model.max_tokens_per_minute as usize / users_in_recent_minutes;
+    let per_user_max_input_tokens_per_minute =
+        model.max_input_tokens_per_minute as usize / users_in_recent_minutes;
+    let per_user_max_output_tokens_per_minute =
+        model.max_output_tokens_per_minute as usize / users_in_recent_minutes;
     let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
 
     let usage = state
@@ -529,6 +533,8 @@ async fn check_usage_limit(
             let resource = match usage_measure {
                 UsageMeasure::RequestsPerMinute => "requests_per_minute",
                 UsageMeasure::TokensPerMinute => "tokens_per_minute",
+                UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute",
+                UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute",
                 UsageMeasure::TokensPerDay => "tokens_per_day",
             };
 
@@ -542,11 +548,15 @@ async fn check_usage_limit(
                 model = model.name,
                 requests_this_minute = usage.requests_this_minute,
                 tokens_this_minute = usage.tokens_this_minute,
+                input_tokens_this_minute = usage.input_tokens_this_minute,
+                output_tokens_this_minute = usage.output_tokens_this_minute,
                 tokens_this_day = usage.tokens_this_day,
                 users_in_recent_minutes = users_in_recent_minutes,
                 users_in_recent_days = users_in_recent_days,
                 max_requests_per_minute = per_user_max_requests_per_minute,
                 max_tokens_per_minute = per_user_max_tokens_per_minute,
+                max_input_tokens_per_minute = per_user_max_input_tokens_per_minute,
+                max_output_tokens_per_minute = per_user_max_output_tokens_per_minute,
                 max_tokens_per_day = per_user_max_tokens_per_day,
             );
 
@@ -658,6 +668,8 @@ impl<S> Drop for TokenCountingStream<S> {
                     is_staff = claims.is_staff,
                     requests_this_minute = usage.requests_this_minute,
                     tokens_this_minute = usage.tokens_this_minute,
+                    input_tokens_this_minute = usage.input_tokens_this_minute,
+                    output_tokens_this_minute = usage.output_tokens_this_minute,
                 );
 
                 let properties = json!({
@@ -726,6 +738,8 @@ pub fn log_usage_periodically(state: Arc<LlmState>) {
                         model = usage.model,
                         requests_this_minute = usage.requests_this_minute,
                         tokens_this_minute = usage.tokens_this_minute,
+                        input_tokens_this_minute = usage.input_tokens_this_minute,
+                        output_tokens_this_minute = usage.output_tokens_this_minute,
                     );
                 }
             }

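For context, the per-user caps in `check_usage_limit` are the model-wide per-minute caps divided evenly among recently active users, and the two new caps follow the same arithmetic. A minimal sketch of that share calculation; the cap values and user count are made-up examples, not real limits:

```rust
/// Sketch of the per-user share arithmetic in `check_usage_limit`: the model-wide
/// `max_*_per_minute` cap is split evenly across `users_in_recent_minutes`.
fn per_user_cap(model_cap_per_minute: i64, users_in_recent_minutes: usize) -> usize {
    // Assumed: the real code never sees zero recent users; clamp so this sketch
    // cannot divide by zero.
    model_cap_per_minute as usize / users_in_recent_minutes.max(1)
}

fn main() {
    // Hypothetical numbers: 2,000,000 input and 400,000 output tokens per minute
    // model-wide, shared by 40 users active in the recent window.
    let max_input_tokens_per_minute = per_user_cap(2_000_000, 40);
    let max_output_tokens_per_minute = per_user_cap(400_000, 40);
    assert_eq!(max_input_tokens_per_minute, 50_000);
    assert_eq!(max_output_tokens_per_minute, 10_000);
}
```

As the description notes, these per-user values are currently only computed and logged alongside `input_tokens_this_minute` and `output_tokens_this_minute`; the PR does not yet reject requests based on them.
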
crates/collab/src/llm/db/queries/usages.rs

@@ -27,6 +27,8 @@ impl TokenUsage {
 pub struct Usage {
     pub requests_this_minute: usize,
     pub tokens_this_minute: usize,
+    pub input_tokens_this_minute: usize,
+    pub output_tokens_this_minute: usize,
     pub tokens_this_day: usize,
     pub tokens_this_month: TokenUsage,
     pub spending_this_month: Cents,
@@ -39,6 +41,8 @@ pub struct ApplicationWideUsage {
     pub model: String,
     pub requests_this_minute: usize,
     pub tokens_this_minute: usize,
+    pub input_tokens_this_minute: usize,
+    pub output_tokens_this_minute: usize,
 }
 
 #[derive(Clone, Copy, Debug, Default)]
@@ -94,6 +98,10 @@ impl LlmDatabase {
             let past_minute = now - Duration::minutes(1);
             let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
             let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
+            let input_tokens_per_minute =
+                self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
+            let output_tokens_per_minute =
+                self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
 
             let mut results = Vec::new();
             for ((provider, model_name), model) in self.models.iter() {
@@ -114,6 +122,8 @@ impl LlmDatabase {
 
                 let mut requests_this_minute = 0;
                 let mut tokens_this_minute = 0;
+                let mut input_tokens_this_minute = 0;
+                let mut output_tokens_this_minute = 0;
                 while let Some(usage) = usages.next().await {
                     let usage = usage?;
                     if usage.measure_id == requests_per_minute {
@@ -136,6 +146,26 @@ impl LlmDatabase {
                         .iter()
                         .copied()
                         .sum::<i64>() as usize;
+                    } else if usage.measure_id == input_tokens_per_minute {
+                        input_tokens_this_minute += Self::get_live_buckets(
+                            &usage,
+                            now.naive_utc(),
+                            UsageMeasure::InputTokensPerMinute,
+                        )
+                        .0
+                        .iter()
+                        .copied()
+                        .sum::<i64>() as usize;
+                    } else if usage.measure_id == output_tokens_per_minute {
+                        output_tokens_this_minute += Self::get_live_buckets(
+                            &usage,
+                            now.naive_utc(),
+                            UsageMeasure::OutputTokensPerMinute,
+                        )
+                        .0
+                        .iter()
+                        .copied()
+                        .sum::<i64>() as usize;
                     }
                 }
 
@@ -144,6 +174,8 @@ impl LlmDatabase {
                     model: model_name.clone(),
                     requests_this_minute,
                     tokens_this_minute,
+                    input_tokens_this_minute,
+                    output_tokens_this_minute,
                 })
             }
 
@@ -239,6 +271,10 @@ impl LlmDatabase {
                 self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
             let tokens_this_minute =
                 self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
+            let input_tokens_this_minute =
+                self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
+            let output_tokens_this_minute =
+                self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
             let tokens_this_day =
                 self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
             let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
@@ -267,6 +303,8 @@ impl LlmDatabase {
             Ok(Usage {
                 requests_this_minute,
                 tokens_this_minute,
+                input_tokens_this_minute,
+                output_tokens_this_minute,
                 tokens_this_day,
                 tokens_this_month: TokenUsage {
                     input: monthly_usage
@@ -337,6 +375,31 @@ impl LlmDatabase {
                     &tx,
                 )
                 .await?;
+            let input_tokens_this_minute = self
+                .update_usage_for_measure(
+                    user_id,
+                    is_staff,
+                    model.id,
+                    &usages,
+                    UsageMeasure::InputTokensPerMinute,
+                    now,
+                    // Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
+                    tokens.input + tokens.input_cache_creation,
+                    &tx,
+                )
+                .await?;
+            let output_tokens_this_minute = self
+                .update_usage_for_measure(
+                    user_id,
+                    is_staff,
+                    model.id,
+                    &usages,
+                    UsageMeasure::OutputTokensPerMinute,
+                    now,
+                    tokens.output,
+                    &tx,
+                )
+                .await?;
             let tokens_this_day = self
                 .update_usage_for_measure(
                     user_id,
@@ -485,6 +548,8 @@ impl LlmDatabase {
             Ok(Usage {
                 requests_this_minute,
                 tokens_this_minute,
+                input_tokens_this_minute,
+                output_tokens_this_minute,
                 tokens_this_day,
                 tokens_this_month: TokenUsage {
                     input: monthly_usage.input_tokens as usize,
@@ -684,7 +749,9 @@ impl UsageMeasure {
     fn bucket_count(&self) -> usize {
         match self {
             UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
-            UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
+            UsageMeasure::TokensPerMinute
+            | UsageMeasure::InputTokensPerMinute
+            | UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
             UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
         }
     }
@@ -692,7 +759,9 @@ impl UsageMeasure {
     fn total_duration(&self) -> Duration {
         match self {
             UsageMeasure::RequestsPerMinute => Duration::minutes(1),
-            UsageMeasure::TokensPerMinute => Duration::minutes(1),
+            UsageMeasure::TokensPerMinute
+            | UsageMeasure::InputTokensPerMinute
+            | UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
             UsageMeasure::TokensPerDay => Duration::hours(24),
         }
     }

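Both new measures reuse the existing fixed-bucket sliding window: per-minute measures keep `MINUTE_BUCKET_COUNT` buckets spanning a one-minute `total_duration`, and only the buckets still inside that window (the live buckets returned by `get_live_buckets`) are summed into the `*_this_minute` totals. The toy counter below illustrates the idea only; it uses a deque of timestamped increments rather than the fixed bucket layout the real queries operate on:

```rust
use std::collections::VecDeque;

use chrono::{Duration, NaiveDateTime};

/// Toy sliding-window counter. The real implementation stores a fixed number of
/// buckets per usage row and sums the live ones; this sketch keeps timestamped
/// increments and drops the expired ones instead.
struct MinuteCounter {
    buckets: VecDeque<(NaiveDateTime, i64)>,
}

impl MinuteCounter {
    fn new() -> Self {
        Self {
            buckets: VecDeque::new(),
        }
    }

    /// Record `tokens` at time `at`.
    fn record(&mut self, at: NaiveDateTime, tokens: i64) {
        self.buckets.push_back((at, tokens));
    }

    /// Drop entries that fell out of the one-minute window and sum the rest,
    /// analogous to summing the live buckets of a per-minute `UsageMeasure`.
    fn total_this_minute(&mut self, now: NaiveDateTime) -> usize {
        let window_start = now - Duration::minutes(1);
        while let Some(&(at, _)) = self.buckets.front() {
            if at <= window_start {
                self.buckets.pop_front();
            } else {
                break;
            }
        }
        self.buckets.iter().map(|&(_, tokens)| tokens).sum::<i64>() as usize
    }
}
```

Note also which tokens feed each new measure: the `update_usage_for_measure` call for `InputTokensPerMinute` passes `tokens.input + tokens.input_cache_creation` (cache reads are billed but excluded from rate limits, per the comment in the diff), while `OutputTokensPerMinute` receives `tokens.output`.
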
crates/collab/src/llm/db/tests/usage_tests.rs

@@ -83,6 +83,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         Usage {
             requests_this_minute: 2,
             tokens_this_minute: 3000,
+            input_tokens_this_minute: 3000,
+            output_tokens_this_minute: 0,
             tokens_this_day: 3000,
             tokens_this_month: TokenUsage {
                 input: 3000,
@@ -102,6 +104,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         Usage {
             requests_this_minute: 1,
             tokens_this_minute: 2000,
+            input_tokens_this_minute: 2000,
+            output_tokens_this_minute: 0,
             tokens_this_day: 3000,
             tokens_this_month: TokenUsage {
                 input: 3000,
@@ -140,6 +144,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         Usage {
             requests_this_minute: 2,
             tokens_this_minute: 5000,
+            input_tokens_this_minute: 5000,
+            output_tokens_this_minute: 0,
             tokens_this_day: 6000,
             tokens_this_month: TokenUsage {
                 input: 6000,
@@ -160,6 +166,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         Usage {
             requests_this_minute: 0,
             tokens_this_minute: 0,
+            input_tokens_this_minute: 0,
+            output_tokens_this_minute: 0,
             tokens_this_day: 5000,
             tokens_this_month: TokenUsage {
                 input: 6000,
@@ -197,6 +205,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         Usage {
             requests_this_minute: 1,
             tokens_this_minute: 4000,
+            input_tokens_this_minute: 4000,
+            output_tokens_this_minute: 0,
             tokens_this_day: 9000,
             tokens_this_month: TokenUsage {
                 input: 10000,
@@ -240,6 +250,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         Usage {
             requests_this_minute: 1,
             tokens_this_minute: 1500,
+            input_tokens_this_minute: 1500,
+            output_tokens_this_minute: 0,
             tokens_this_day: 1500,
             tokens_this_month: TokenUsage {
                 input: 1000,
@@ -278,6 +290,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         Usage {
             requests_this_minute: 2,
             tokens_this_minute: 2800,
+            input_tokens_this_minute: 2500,
+            output_tokens_this_minute: 0,
             tokens_this_day: 2800,
             tokens_this_month: TokenUsage {
                 input: 2000,