Track lifetime spending for each user and model (#16137)

Max Brunsfeld and Marshall created

Release Notes:

- N/A

Co-authored-by: Marshall <marshall@zed.dev>

Change summary

crates/collab/migrations_llm/20240812225346_create_lifetime_usages.sql |  9 
crates/collab/src/llm.rs                                               | 14 
crates/collab/src/llm/db/queries/usages.rs                             | 68 
crates/collab/src/llm/db/tables.rs                                     |  1 
crates/collab/src/llm/db/tables/lifetime_usage.rs                      | 18 
crates/collab/src/llm/db/tables/usage.rs                               |  7 
crates/collab/src/llm/db/tests/usage_tests.rs                          | 14 
crates/collab/src/llm/telemetry.rs                                     |  1 
8 files changed, 121 insertions(+), 11 deletions(-)

Detailed changes

crates/collab/migrations_llm/20240812225346_create_lifetime_usages.sql 🔗

@@ -0,0 +1,9 @@
+create table lifetime_usages (
+    id serial primary key,
+    user_id integer not null,
+    model_id integer not null references models (id) on delete cascade,
+    input_tokens bigint not null default 0,
+    output_tokens bigint not null default 0
+);
+
+create unique index uix_lifetime_usages_on_user_id_model_id on lifetime_usages (user_id, model_id);

crates/collab/src/llm.rs 🔗

@@ -4,8 +4,8 @@ mod telemetry;
 mod token;
 
 use crate::{
-    api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error,
-    Result,
+    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
+    Config, Error, Result,
 };
 use anyhow::{anyhow, Context as _};
 use authorization::authorize_access_to_language_model;
@@ -396,7 +396,12 @@ async fn check_usage_limit(
     let model = state.db.model(provider, model_name)?;
     let usage = state
         .db
-        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
+        .get_usage(
+            UserId::from_proto(claims.user_id),
+            provider,
+            model_name,
+            Utc::now(),
+        )
         .await?;
 
     let active_users = state.get_active_user_count().await?;
@@ -523,7 +528,7 @@ impl<S> Drop for TokenCountingStream<S> {
             let usage = state
                 .db
                 .record_usage(
-                    claims.user_id as i32,
+                    UserId::from_proto(claims.user_id),
                     claims.is_staff,
                     provider,
                     &model,
@@ -555,6 +560,7 @@ impl<S> Drop for TokenCountingStream<S> {
                         input_tokens_this_month: usage.input_tokens_this_month as u64,
                         output_tokens_this_month: usage.output_tokens_this_month as u64,
                         spending_this_month: usage.spending_this_month as u64,
+                        lifetime_spending: usage.lifetime_spending as u64,
                     },
                 )
                 .await

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -1,3 +1,4 @@
+use crate::db::UserId;
 use chrono::Duration;
 use rpc::LanguageModelProvider;
 use sea_orm::QuerySelect;
@@ -14,6 +15,7 @@ pub struct Usage {
     pub input_tokens_this_month: usize,
     pub output_tokens_this_month: usize,
     pub spending_this_month: usize,
+    pub lifetime_spending: usize,
 }
 
 #[derive(Clone, Copy, Debug, Default)]
@@ -63,7 +65,7 @@ impl LlmDatabase {
 
     pub async fn get_usage(
         &self,
-        user_id: i32,
+        user_id: UserId,
         provider: LanguageModelProvider,
         model_name: &str,
         now: DateTimeUtc,
@@ -83,6 +85,18 @@ impl LlmDatabase {
                 .all(&*tx)
                 .await?;
 
+            let (lifetime_input_tokens, lifetime_output_tokens) = lifetime_usage::Entity::find()
+                .filter(
+                    lifetime_usage::Column::UserId
+                        .eq(user_id)
+                        .and(lifetime_usage::Column::ModelId.eq(model.id)),
+                )
+                .one(&*tx)
+                .await?
+                .map_or((0, 0), |usage| {
+                    (usage.input_tokens as usize, usage.output_tokens as usize)
+                });
+
             let requests_this_minute =
                 self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
             let tokens_this_minute =
@@ -95,6 +109,8 @@ impl LlmDatabase {
                 self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?;
             let spending_this_month =
                 calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
+            let lifetime_spending =
+                calculate_spending(model, lifetime_input_tokens, lifetime_output_tokens);
 
             Ok(Usage {
                 requests_this_minute,
@@ -103,6 +119,7 @@ impl LlmDatabase {
                 input_tokens_this_month,
                 output_tokens_this_month,
                 spending_this_month,
+                lifetime_spending,
             })
         })
         .await
@@ -111,7 +128,7 @@ impl LlmDatabase {
     #[allow(clippy::too_many_arguments)]
     pub async fn record_usage(
         &self,
-        user_id: i32,
+        user_id: UserId,
         is_staff: bool,
         provider: LanguageModelProvider,
         model_name: &str,
@@ -194,6 +211,50 @@ impl LlmDatabase {
             let spending_this_month =
                 calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
 
+            // Update lifetime usage
+            let lifetime_usage = lifetime_usage::Entity::find()
+                .filter(
+                    lifetime_usage::Column::UserId
+                        .eq(user_id)
+                        .and(lifetime_usage::Column::ModelId.eq(model.id)),
+                )
+                .one(&*tx)
+                .await?;
+
+            let lifetime_usage = match lifetime_usage {
+                Some(usage) => {
+                    lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
+                        id: ActiveValue::unchanged(usage.id),
+                        input_tokens: ActiveValue::set(
+                            usage.input_tokens + input_token_count as i64,
+                        ),
+                        output_tokens: ActiveValue::set(
+                            usage.output_tokens + output_token_count as i64,
+                        ),
+                        ..Default::default()
+                    })
+                    .exec(&*tx)
+                    .await?
+                }
+                None => {
+                    lifetime_usage::ActiveModel {
+                        user_id: ActiveValue::set(user_id),
+                        model_id: ActiveValue::set(model.id),
+                        input_tokens: ActiveValue::set(input_token_count as i64),
+                        output_tokens: ActiveValue::set(output_token_count as i64),
+                        ..Default::default()
+                    }
+                    .insert(&*tx)
+                    .await?
+                }
+            };
+
+            let lifetime_spending = calculate_spending(
+                model,
+                lifetime_usage.input_tokens as usize,
+                lifetime_usage.output_tokens as usize,
+            );
+
             Ok(Usage {
                 requests_this_minute,
                 tokens_this_minute,
@@ -201,6 +262,7 @@ impl LlmDatabase {
                 input_tokens_this_month,
                 output_tokens_this_month,
                 spending_this_month,
+                lifetime_spending,
             })
         })
         .await
@@ -246,7 +308,7 @@ impl LlmDatabase {
     #[allow(clippy::too_many_arguments)]
     async fn update_usage_for_measure(
         &self,
-        user_id: i32,
+        user_id: UserId,
         is_staff: bool,
         model_id: ModelId,
         usages: &[usage::Model],

crates/collab/src/llm/db/tables/lifetime_usage.rs 🔗

@@ -0,0 +1,18 @@
+use crate::{db::UserId, llm::db::ModelId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "lifetime_usages")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: i32,
+    pub user_id: UserId,
+    pub model_id: ModelId,
+    pub input_tokens: i64,
+    pub output_tokens: i64,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tables/usage.rs 🔗

@@ -1,4 +1,7 @@
-use crate::llm::db::{ModelId, UsageId, UsageMeasureId};
+use crate::{
+    db::UserId,
+    llm::db::{ModelId, UsageId, UsageMeasureId},
+};
 use sea_orm::entity::prelude::*;
 
 /// An LLM usage record.
@@ -10,7 +13,7 @@ pub struct Model {
     /// The ID of the Zed user.
     ///
     /// Corresponds to the `users` table in the primary collab database.
-    pub user_id: i32,
+    pub user_id: UserId,
     pub model_id: ModelId,
     pub measure_id: UsageMeasureId,
     pub timestamp: DateTime,

crates/collab/src/llm/db/tests/usage_tests.rs 🔗

@@ -1,5 +1,9 @@
 use crate::{
-    llm::db::{queries::providers::ModelParams, queries::usages::Usage, LlmDatabase},
+    db::UserId,
+    llm::db::{
+        queries::{providers::ModelParams, usages::Usage},
+        LlmDatabase,
+    },
     test_llm_db,
 };
 use chrono::{Duration, Utc};
@@ -26,7 +30,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
     .unwrap();
 
     let t0 = Utc::now();
-    let user_id = 123;
+    let user_id = UserId::from_proto(123);
 
     let now = t0;
     db.record_usage(user_id, false, provider, model, 1000, 0, now)
@@ -48,6 +52,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             input_tokens_this_month: 3000,
             output_tokens_this_month: 0,
             spending_this_month: 0,
+            lifetime_spending: 0,
         }
     );
 
@@ -62,6 +67,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             input_tokens_this_month: 3000,
             output_tokens_this_month: 0,
             spending_this_month: 0,
+            lifetime_spending: 0,
         }
     );
 
@@ -80,6 +86,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             input_tokens_this_month: 6000,
             output_tokens_this_month: 0,
             spending_this_month: 0,
+            lifetime_spending: 0,
         }
     );
 
@@ -95,6 +102,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             input_tokens_this_month: 6000,
             output_tokens_this_month: 0,
             spending_this_month: 0,
+            lifetime_spending: 0,
         }
     );
 
@@ -112,6 +120,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             input_tokens_this_month: 10000,
             output_tokens_this_month: 0,
             spending_this_month: 0,
+            lifetime_spending: 0,
         }
     );
 
@@ -127,6 +136,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             input_tokens_this_month: 9000,
             output_tokens_this_month: 0,
             spending_this_month: 0,
+            lifetime_spending: 0,
         }
     );
 }

crates/collab/src/llm/telemetry.rs 🔗

@@ -17,6 +17,7 @@ pub struct LlmUsageEventRow {
     pub input_tokens_this_month: u64,
     pub output_tokens_this_month: u64,
     pub spending_this_month: u64,
+    pub lifetime_spending: u64,
 }
 
 #[derive(Serialize, Debug, clickhouse::Row)]