collab: Update billing code for LLM usage billing (#18879)

Marshall Bowers , Antonio Scandurra , Richard , and Max created

This PR reworks our existing billing code in preparation for charging
based on LLM usage.

We aren't yet exercising the new billing-related code outside of
development.

There are some noteworthy changes for our existing LLM usage tracking:

- A new `monthly_usages` table has been added for tracking usage
per-user, per-model, per-month
- The per-month usage measures have been removed, in favor of the
`monthly_usages` table
- All of the per-month metrics in the Clickhouse rows have been changed
from a rolling 30-day window to a calendar month

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Max <max@zed.dev>

Change summary

crates/collab/migrations_llm/20241008155620_create_monthly_usages.sql |  13 
crates/collab/src/api/billing.rs                                      | 103 
crates/collab/src/db/queries/billing_subscriptions.rs                 |  23 
crates/collab/src/lib.rs                                              |   8 
crates/collab/src/llm.rs                                              |  16 
crates/collab/src/llm/db.rs                                           |   8 
crates/collab/src/llm/db/queries/usages.rs                            | 230 
crates/collab/src/llm/db/tables.rs                                    |   1 
crates/collab/src/llm/db/tables/monthly_usage.rs                      |  22 
crates/collab/src/llm/db/tables/usage_measure.rs                      |   4 
crates/collab/src/llm/db/tests/usage_tests.rs                         |  32 
crates/collab/src/llm/token.rs                                        |   8 
crates/collab/src/main.rs                                             |  25 
crates/collab/src/rpc.rs                                              |  27 
crates/collab/src/tests/test_server.rs                                |   2 
15 files changed, 390 insertions(+), 132 deletions(-)

Detailed changes

crates/collab/migrations_llm/20241008155620_create_monthly_usages.sql 🔗

@@ -0,0 +1,13 @@
+create table monthly_usages (
+    id serial primary key,
+    user_id integer not null,
+    model_id integer not null references models (id) on delete cascade,
+    month integer not null,
+    year integer not null,
+    input_tokens bigint not null default 0,
+    cache_creation_input_tokens bigint not null default 0,
+    cache_read_input_tokens bigint not null default 0,
+    output_tokens bigint not null default 0
+);
+
+create unique index uix_monthly_usages_on_user_id_model_id_month_year on monthly_usages (user_id, model_id, month, year);

crates/collab/src/api/billing.rs 🔗

@@ -22,12 +22,15 @@ use stripe::{
 };
 use util::ResultExt;
 
-use crate::db::billing_subscription::StripeSubscriptionStatus;
+use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
 use crate::db::{
     billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
     CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
     UpdateBillingSubscriptionParams,
 };
+use crate::llm::db::LlmDatabase;
+use crate::llm::MONTHLY_SPENDING_LIMIT_IN_CENTS;
+use crate::rpc::ResultExt as _;
 use crate::{AppState, Error, Result};
 
 pub fn router() -> Router {
@@ -79,7 +82,7 @@ async fn list_billing_subscriptions(
             .into_iter()
             .map(|subscription| BillingSubscriptionJson {
                 id: subscription.id,
-                name: "Zed Pro".to_string(),
+                name: "Zed LLM Usage".to_string(),
                 status: subscription.stripe_subscription_status,
                 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
                     cancel_at
@@ -117,7 +120,7 @@ async fn create_billing_subscription(
     let Some((stripe_client, stripe_price_id)) = app
         .stripe_client
         .clone()
-        .zip(app.config.stripe_price_id.clone())
+        .zip(app.config.stripe_llm_usage_price_id.clone())
     else {
         log::error!("failed to retrieve Stripe client or price ID");
         Err(Error::http(
@@ -150,7 +153,7 @@ async fn create_billing_subscription(
         params.client_reference_id = Some(user.github_login.as_str());
         params.line_items = Some(vec![CreateCheckoutSessionLineItems {
             price: Some(stripe_price_id.to_string()),
-            quantity: Some(1),
+            quantity: Some(0),
             ..Default::default()
         }]);
         let success_url = format!("{}/account", app.config.zed_dot_dev_url());
@@ -631,3 +634,95 @@ async fn find_or_create_billing_customer(
 
     Ok(Some(billing_customer))
 }
+
+const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
+
+pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
+    let Some(stripe_client) = app.stripe_client.clone() else {
+        log::warn!("failed to retrieve Stripe client");
+        return;
+    };
+    let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
+        log::warn!("failed to retrieve Stripe LLM usage price ID");
+        return;
+    };
+
+    let executor = app.executor.clone();
+    executor.spawn_detached({
+        let executor = executor.clone();
+        async move {
+            loop {
+                sync_with_stripe(
+                    &app,
+                    &llm_db,
+                    &stripe_client,
+                    stripe_llm_usage_price_id.clone(),
+                )
+                .await
+                .trace_err();
+
+                executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
+            }
+        }
+    });
+}
+
+async fn sync_with_stripe(
+    app: &Arc<AppState>,
+    llm_db: &LlmDatabase,
+    stripe_client: &stripe::Client,
+    stripe_llm_usage_price_id: Arc<str>,
+) -> anyhow::Result<()> {
+    let subscriptions = app.db.get_active_billing_subscriptions().await?;
+
+    for (customer, subscription) in subscriptions {
+        update_stripe_subscription(
+            llm_db,
+            stripe_client,
+            &stripe_llm_usage_price_id,
+            customer,
+            subscription,
+        )
+        .await
+        .log_err();
+    }
+
+    Ok(())
+}
+
+async fn update_stripe_subscription(
+    llm_db: &LlmDatabase,
+    stripe_client: &stripe::Client,
+    stripe_llm_usage_price_id: &Arc<str>,
+    customer: billing_customer::Model,
+    subscription: billing_subscription::Model,
+) -> Result<(), anyhow::Error> {
+    let monthly_spending = llm_db
+        .get_user_spending_for_month(customer.user_id, Utc::now())
+        .await?;
+    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
+        .context("failed to parse subscription ID")?;
+
+    let monthly_spending_over_free_tier =
+        monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT_IN_CENTS);
+
+    let new_quantity = (monthly_spending_over_free_tier as f32 / 100.).ceil();
+    Subscription::update(
+        stripe_client,
+        &subscription_id,
+        stripe::UpdateSubscription {
+            items: Some(vec![stripe::UpdateSubscriptionItems {
+                // TODO: Do we need to send up the `id` if a subscription item
+                // with this price already exists, or will Stripe take care of
+                // it?
+                id: None,
+                price: Some(stripe_llm_usage_price_id.to_string()),
+                quantity: Some(new_quantity as u64),
+                ..Default::default()
+            }]),
+            ..Default::default()
+        },
+    )
+    .await?;
+    Ok(())
+}

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -112,6 +112,29 @@ impl Database {
         .await
     }
 
+    pub async fn get_active_billing_subscriptions(
+        &self,
+    ) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
+        self.transaction(|tx| async move {
+            let mut result = Vec::new();
+            let mut rows = billing_subscription::Entity::find()
+                .inner_join(billing_customer::Entity)
+                .select_also(billing_customer::Entity)
+                .order_by_asc(billing_subscription::Column::Id)
+                .stream(&*tx)
+                .await?;
+
+            while let Some(row) = rows.next().await {
+                if let (subscription, Some(customer)) = row? {
+                    result.push((customer, subscription));
+                }
+            }
+
+            Ok(result)
+        })
+        .await
+    }
+
     /// Returns whether the user has an active billing subscription.
     pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
         Ok(self.count_active_billing_subscriptions(user_id).await? > 0)

crates/collab/src/lib.rs 🔗

@@ -174,7 +174,7 @@ pub struct Config {
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,
     pub stripe_api_key: Option<String>,
-    pub stripe_price_id: Option<Arc<str>>,
+    pub stripe_llm_usage_price_id: Option<Arc<str>>,
     pub supermaven_admin_api_key: Option<Arc<str>>,
     pub user_backfiller_github_access_token: Option<Arc<str>>,
 }
@@ -193,6 +193,10 @@ impl Config {
         }
     }
 
+    pub fn is_llm_billing_enabled(&self) -> bool {
+        self.stripe_llm_usage_price_id.is_some()
+    }
+
     #[cfg(test)]
     pub fn test() -> Self {
         Self {
@@ -231,7 +235,7 @@ impl Config {
             migrations_path: None,
             seed_path: None,
             stripe_api_key: None,
-            stripe_price_id: None,
+            stripe_llm_usage_price_id: None,
             supermaven_admin_api_key: None,
             user_backfiller_github_access_token: None,
         }

crates/collab/src/llm.rs 🔗

@@ -436,6 +436,9 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
     }
 }
 
+/// The maximum monthly spending an individual user can reach before they have to pay.
+pub const MONTHLY_SPENDING_LIMIT_IN_CENTS: usize = 5 * 100;
+
 /// The maximum lifetime spending an individual user can reach before being cut off.
 ///
 /// Represented in cents.
@@ -458,6 +461,18 @@ async fn check_usage_limit(
         )
         .await?;
 
+    if state.config.is_llm_billing_enabled() {
+        if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT_IN_CENTS {
+            if !claims.has_llm_subscription.unwrap_or(false) {
+                return Err(Error::http(
+                    StatusCode::PAYMENT_REQUIRED,
+                    "Maximum spending limit reached for this month.".to_string(),
+                ));
+            }
+        }
+    }
+
+    // TODO: Remove this once we've rolled out monthly spending limits.
     if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
         return Err(Error::http(
             StatusCode::FORBIDDEN,
@@ -505,7 +520,6 @@ async fn check_usage_limit(
                 UsageMeasure::RequestsPerMinute => "requests_per_minute",
                 UsageMeasure::TokensPerMinute => "tokens_per_minute",
                 UsageMeasure::TokensPerDay => "tokens_per_day",
-                _ => "",
             };
 
             if let Some(client) = state.clickhouse_client.as_ref() {

crates/collab/src/llm/db.rs 🔗

@@ -97,6 +97,14 @@ impl LlmDatabase {
             .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
     }
 
+    pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
+        Ok(self
+            .models
+            .values()
+            .find(|model| model.id == id)
+            .ok_or_else(|| anyhow!("no model for ID {id:?}"))?)
+    }
+
     pub fn options(&self) -> &ConnectOptions {
         &self.options
     }

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -1,5 +1,5 @@
 use crate::db::UserId;
-use chrono::Duration;
+use chrono::{Datelike, Duration};
 use futures::StreamExt as _;
 use rpc::LanguageModelProvider;
 use sea_orm::QuerySelect;
@@ -140,6 +140,46 @@ impl LlmDatabase {
         .await
     }
 
+    pub async fn get_user_spending_for_month(
+        &self,
+        user_id: UserId,
+        now: DateTimeUtc,
+    ) -> Result<usize> {
+        self.transaction(|tx| async move {
+            let month = now.date_naive().month() as i32;
+            let year = now.date_naive().year();
+
+            let mut monthly_usages = monthly_usage::Entity::find()
+                .filter(
+                    monthly_usage::Column::UserId
+                        .eq(user_id)
+                        .and(monthly_usage::Column::Month.eq(month))
+                        .and(monthly_usage::Column::Year.eq(year)),
+                )
+                .stream(&*tx)
+                .await?;
+            let mut monthly_spending_in_cents = 0;
+
+            while let Some(usage) = monthly_usages.next().await {
+                let usage = usage?;
+                let Ok(model) = self.model_by_id(usage.model_id) else {
+                    continue;
+                };
+
+                monthly_spending_in_cents += calculate_spending(
+                    model,
+                    usage.input_tokens as usize,
+                    usage.cache_creation_input_tokens as usize,
+                    usage.cache_read_input_tokens as usize,
+                    usage.output_tokens as usize,
+                );
+            }
+
+            Ok(monthly_spending_in_cents)
+        })
+        .await
+    }
+
     pub async fn get_usage(
         &self,
         user_id: UserId,
@@ -162,6 +202,18 @@ impl LlmDatabase {
                 .all(&*tx)
                 .await?;
 
+            let month = now.date_naive().month() as i32;
+            let year = now.date_naive().year();
+            let monthly_usage = monthly_usage::Entity::find()
+                .filter(
+                    monthly_usage::Column::UserId
+                        .eq(user_id)
+                        .and(monthly_usage::Column::ModelId.eq(model.id))
+                        .and(monthly_usage::Column::Month.eq(month))
+                        .and(monthly_usage::Column::Year.eq(year)),
+                )
+                .one(&*tx)
+                .await?;
             let lifetime_usage = lifetime_usage::Entity::find()
                 .filter(
                     lifetime_usage::Column::UserId
@@ -177,28 +229,18 @@ impl LlmDatabase {
                 self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
             let tokens_this_day =
                 self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
-            let input_tokens_this_month =
-                self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMonth)?;
-            let cache_creation_input_tokens_this_month = self.get_usage_for_measure(
-                &usages,
-                now,
-                UsageMeasure::CacheCreationInputTokensPerMonth,
-            )?;
-            let cache_read_input_tokens_this_month = self.get_usage_for_measure(
-                &usages,
-                now,
-                UsageMeasure::CacheReadInputTokensPerMonth,
-            )?;
-            let output_tokens_this_month =
-                self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?;
-            let spending_this_month = calculate_spending(
-                model,
-                input_tokens_this_month,
-                cache_creation_input_tokens_this_month,
-                cache_read_input_tokens_this_month,
-                output_tokens_this_month,
-            );
-            let lifetime_spending = if let Some(lifetime_usage) = lifetime_usage {
+            let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
+                calculate_spending(
+                    model,
+                    monthly_usage.input_tokens as usize,
+                    monthly_usage.cache_creation_input_tokens as usize,
+                    monthly_usage.cache_read_input_tokens as usize,
+                    monthly_usage.output_tokens as usize,
+                )
+            } else {
+                0
+            };
+            let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
                 calculate_spending(
                     model,
                     lifetime_usage.input_tokens as usize,
@@ -214,10 +256,18 @@ impl LlmDatabase {
                 requests_this_minute,
                 tokens_this_minute,
                 tokens_this_day,
-                input_tokens_this_month,
-                cache_creation_input_tokens_this_month,
-                cache_read_input_tokens_this_month,
-                output_tokens_this_month,
+                input_tokens_this_month: monthly_usage
+                    .as_ref()
+                    .map_or(0, |usage| usage.input_tokens as usize),
+                cache_creation_input_tokens_this_month: monthly_usage
+                    .as_ref()
+                    .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
+                cache_read_input_tokens_this_month: monthly_usage
+                    .as_ref()
+                    .map_or(0, |usage| usage.cache_read_input_tokens as usize),
+                output_tokens_this_month: monthly_usage
+                    .as_ref()
+                    .map_or(0, |usage| usage.output_tokens as usize),
                 spending_this_month,
                 lifetime_spending,
             })
@@ -290,60 +340,68 @@ impl LlmDatabase {
                     &tx,
                 )
                 .await?;
-            let input_tokens_this_month = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::InputTokensPerMonth,
-                    now,
-                    input_token_count,
-                    &tx,
-                )
-                .await?;
-            let cache_creation_input_tokens_this_month = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::CacheCreationInputTokensPerMonth,
-                    now,
-                    cache_creation_input_tokens,
-                    &tx,
-                )
-                .await?;
-            let cache_read_input_tokens_this_month = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::CacheReadInputTokensPerMonth,
-                    now,
-                    cache_read_input_tokens,
-                    &tx,
-                )
-                .await?;
-            let output_tokens_this_month = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::OutputTokensPerMonth,
-                    now,
-                    output_token_count,
-                    &tx,
+
+            let month = now.date_naive().month() as i32;
+            let year = now.date_naive().year();
+
+            // Update monthly usage
+            let monthly_usage = monthly_usage::Entity::find()
+                .filter(
+                    monthly_usage::Column::UserId
+                        .eq(user_id)
+                        .and(monthly_usage::Column::ModelId.eq(model.id))
+                        .and(monthly_usage::Column::Month.eq(month))
+                        .and(monthly_usage::Column::Year.eq(year)),
                 )
+                .one(&*tx)
                 .await?;
+
+            let monthly_usage = match monthly_usage {
+                Some(usage) => {
+                    monthly_usage::Entity::update(monthly_usage::ActiveModel {
+                        id: ActiveValue::unchanged(usage.id),
+                        input_tokens: ActiveValue::set(
+                            usage.input_tokens + input_token_count as i64,
+                        ),
+                        cache_creation_input_tokens: ActiveValue::set(
+                            usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
+                        ),
+                        cache_read_input_tokens: ActiveValue::set(
+                            usage.cache_read_input_tokens + cache_read_input_tokens as i64,
+                        ),
+                        output_tokens: ActiveValue::set(
+                            usage.output_tokens + output_token_count as i64,
+                        ),
+                        ..Default::default()
+                    })
+                    .exec(&*tx)
+                    .await?
+                }
+                None => {
+                    monthly_usage::ActiveModel {
+                        user_id: ActiveValue::set(user_id),
+                        model_id: ActiveValue::set(model.id),
+                        month: ActiveValue::set(month),
+                        year: ActiveValue::set(year),
+                        input_tokens: ActiveValue::set(input_token_count as i64),
+                        cache_creation_input_tokens: ActiveValue::set(
+                            cache_creation_input_tokens as i64,
+                        ),
+                        cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
+                        output_tokens: ActiveValue::set(output_token_count as i64),
+                        ..Default::default()
+                    }
+                    .insert(&*tx)
+                    .await?
+                }
+            };
+
             let spending_this_month = calculate_spending(
                 model,
-                input_tokens_this_month,
-                cache_creation_input_tokens_this_month,
-                cache_read_input_tokens_this_month,
-                output_tokens_this_month,
+                monthly_usage.input_tokens as usize,
+                monthly_usage.cache_creation_input_tokens as usize,
+                monthly_usage.cache_read_input_tokens as usize,
+                monthly_usage.output_tokens as usize,
             );
 
             // Update lifetime usage
@@ -406,10 +464,11 @@ impl LlmDatabase {
                 requests_this_minute,
                 tokens_this_minute,
                 tokens_this_day,
-                input_tokens_this_month,
-                cache_creation_input_tokens_this_month,
-                cache_read_input_tokens_this_month,
-                output_tokens_this_month,
+                input_tokens_this_month: monthly_usage.input_tokens as usize,
+                cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
+                    as usize,
+                cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
+                output_tokens_this_month: monthly_usage.output_tokens as usize,
                 spending_this_month,
                 lifetime_spending,
             })
@@ -597,7 +656,6 @@ fn calculate_spending(
 
 const MINUTE_BUCKET_COUNT: usize = 12;
 const DAY_BUCKET_COUNT: usize = 48;
-const MONTH_BUCKET_COUNT: usize = 30;
 
 impl UsageMeasure {
     fn bucket_count(&self) -> usize {
@@ -605,10 +663,6 @@ impl UsageMeasure {
             UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
             UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
             UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
-            UsageMeasure::InputTokensPerMonth => MONTH_BUCKET_COUNT,
-            UsageMeasure::CacheCreationInputTokensPerMonth => MONTH_BUCKET_COUNT,
-            UsageMeasure::CacheReadInputTokensPerMonth => MONTH_BUCKET_COUNT,
-            UsageMeasure::OutputTokensPerMonth => MONTH_BUCKET_COUNT,
         }
     }
 
@@ -617,10 +671,6 @@ impl UsageMeasure {
             UsageMeasure::RequestsPerMinute => Duration::minutes(1),
             UsageMeasure::TokensPerMinute => Duration::minutes(1),
             UsageMeasure::TokensPerDay => Duration::hours(24),
-            UsageMeasure::InputTokensPerMonth => Duration::days(30),
-            UsageMeasure::CacheCreationInputTokensPerMonth => Duration::days(30),
-            UsageMeasure::CacheReadInputTokensPerMonth => Duration::days(30),
-            UsageMeasure::OutputTokensPerMonth => Duration::days(30),
         }
     }
 

crates/collab/src/llm/db/tables/monthly_usage.rs 🔗

@@ -0,0 +1,22 @@
+use crate::{db::UserId, llm::db::ModelId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "monthly_usages")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: i32,
+    pub user_id: UserId,
+    pub model_id: ModelId,
+    pub month: i32,
+    pub year: i32,
+    pub input_tokens: i64,
+    pub cache_creation_input_tokens: i64,
+    pub cache_read_input_tokens: i64,
+    pub output_tokens: i64,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tables/usage_measure.rs 🔗

@@ -9,10 +9,6 @@ pub enum UsageMeasure {
     RequestsPerMinute,
     TokensPerMinute,
     TokensPerDay,
-    InputTokensPerMonth,
-    CacheCreationInputTokensPerMonth,
-    CacheReadInputTokensPerMonth,
-    OutputTokensPerMonth,
 }
 
 #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]

crates/collab/src/llm/db/tests/usage_tests.rs 🔗

@@ -6,7 +6,7 @@ use crate::{
     },
     test_llm_db,
 };
-use chrono::{Duration, Utc};
+use chrono::{DateTime, Duration, Utc};
 use pretty_assertions::assert_eq;
 use rpc::LanguageModelProvider;
 
@@ -29,7 +29,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
     .await
     .unwrap();
 
-    let t0 = Utc::now();
+    // We're using a fixed datetime to prevent flakiness based on the clock.
+    let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
+        .unwrap()
+        .with_timezone(&Utc);
     let user_id = UserId::from_proto(123);
 
     let now = t0;
@@ -134,23 +137,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         }
     );
 
-    let t2 = t0 + Duration::days(30);
-    let now = t2;
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 0,
-            tokens_this_minute: 0,
-            tokens_this_day: 0,
-            input_tokens_this_month: 9000,
-            cache_creation_input_tokens_this_month: 0,
-            cache_read_input_tokens_this_month: 0,
-            output_tokens_this_month: 0,
-            spending_this_month: 0,
-            lifetime_spending: 0,
-        }
-    );
+    // We're using a fixed datetime to prevent flakiness based on the clock.
+    let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
+        .unwrap()
+        .with_timezone(&Utc);
 
     // Test cache creation input tokens
     db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
@@ -164,7 +154,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 1,
             tokens_this_minute: 1500,
             tokens_this_day: 1500,
-            input_tokens_this_month: 10000,
+            input_tokens_this_month: 1000,
             cache_creation_input_tokens_this_month: 500,
             cache_read_input_tokens_this_month: 0,
             output_tokens_this_month: 0,
@@ -185,7 +175,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 2,
             tokens_this_minute: 2800,
             tokens_this_day: 2800,
-            input_tokens_this_month: 11000,
+            input_tokens_this_month: 2000,
             cache_creation_input_tokens_this_month: 500,
             cache_read_input_tokens_this_month: 300,
             output_tokens_this_month: 0,

crates/collab/src/llm/token.rs 🔗

@@ -22,6 +22,12 @@ pub struct LlmTokenClaims {
     pub is_staff: bool,
     #[serde(default)]
     pub has_llm_closed_beta_feature_flag: bool,
+    // This field is temporarily optional so it can be added
+    // in a backwards-compatible way. We can make it required
+    // once all of the LLM tokens have cycled (~1 hour after
+    // this change has been deployed).
+    #[serde(default)]
+    pub has_llm_subscription: Option<bool>,
     pub plan: rpc::proto::Plan,
 }
 
@@ -33,6 +39,7 @@ impl LlmTokenClaims {
         github_user_login: String,
         is_staff: bool,
         has_llm_closed_beta_feature_flag: bool,
+        has_llm_subscription: bool,
         plan: rpc::proto::Plan,
         config: &Config,
     ) -> Result<String> {
@@ -50,6 +57,7 @@ impl LlmTokenClaims {
             github_user_login: Some(github_user_login),
             is_staff,
             has_llm_closed_beta_feature_flag,
+            has_llm_subscription: Some(has_llm_subscription),
             plan,
         };
 

crates/collab/src/main.rs 🔗

@@ -6,6 +6,7 @@ use axum::{
     routing::get,
     Extension, Router,
 };
+use collab::api::billing::sync_llm_usage_with_stripe_periodically;
 use collab::api::CloudflareIpCountryHeader;
 use collab::llm::{db::LlmDatabase, log_usage_periodically};
 use collab::migrations::run_database_migrations;
@@ -29,7 +30,7 @@ use tower_http::trace::TraceLayer;
 use tracing_subscriber::{
     filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
 };
-use util::ResultExt as _;
+use util::{maybe, ResultExt as _};
 
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
@@ -136,6 +137,28 @@ async fn main() -> Result<()> {
                     fetch_extensions_from_blob_store_periodically(state.clone());
                     spawn_user_backfiller(state.clone());
 
+                    let llm_db = maybe!(async {
+                        let database_url = state
+                            .config
+                            .llm_database_url
+                            .as_ref()
+                            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+                        let max_connections = state
+                            .config
+                            .llm_database_max_connections
+                            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
+
+                        let mut db_options = db::ConnectOptions::new(database_url);
+                        db_options.max_connections(max_connections);
+                        LlmDatabase::new(db_options, state.executor.clone()).await
+                    })
+                    .await
+                    .trace_err();
+
+                    if let Some(llm_db) = llm_db {
+                        sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
+                    }
+
                     app = app
                         .merge(collab::api::events::router())
                         .merge(collab::api::extensions::router())

crates/collab/src/rpc.rs 🔗

@@ -191,16 +191,26 @@ impl Session {
         }
     }
 
-    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
+    pub async fn has_llm_subscription(
+        &self,
+        db: &MutexGuard<'_, DbHandle>,
+    ) -> anyhow::Result<bool> {
         if self.is_staff() {
-            return Ok(proto::Plan::ZedPro);
+            return Ok(true);
         }
 
         let Some(user_id) = self.user_id() else {
-            return Ok(proto::Plan::Free);
+            return Ok(false);
         };
 
-        if db.has_active_billing_subscription(user_id).await? {
+        Ok(db.has_active_billing_subscription(user_id).await?)
+    }
+
+    pub async fn current_plan(
+        &self,
+        _db: &MutexGuard<'_, DbHandle>,
+    ) -> anyhow::Result<proto::Plan> {
+        if self.is_staff() {
             Ok(proto::Plan::ZedPro)
         } else {
             Ok(proto::Plan::Free)
@@ -3471,7 +3481,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
 }
 
 async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
-    let plan = session.current_plan(session.db().await).await?;
+    let plan = session.current_plan(&session.db().await).await?;
 
     session
         .peer
@@ -4471,7 +4481,7 @@ async fn count_language_model_tokens(
     };
     authorize_access_to_legacy_llm_endpoints(&session).await?;
 
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
         proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
         proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
     };
@@ -4592,7 +4602,7 @@ async fn compute_embeddings(
     let api_key = api_key.context("no OpenAI API key configured on the server")?;
     authorize_access_to_legacy_llm_endpoints(&session).await?;
 
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
         proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
         proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
     };
@@ -4915,7 +4925,8 @@ async fn get_llm_api_token(
         user.github_login.clone(),
         session.is_staff(),
         has_llm_closed_beta_feature_flag,
-        session.current_plan(db).await?,
+        session.has_llm_subscription(&db).await?,
+        session.current_plan(&db).await?,
         &session.app_state.config,
     )?;
     response.send(proto::GetLlmTokenResponse { token })?;

crates/collab/src/tests/test_server.rs 🔗

@@ -677,7 +677,7 @@ impl TestServer {
                 migrations_path: None,
                 seed_path: None,
                 stripe_api_key: None,
-                stripe_price_id: None,
+                stripe_llm_usage_price_id: None,
                 supermaven_admin_api_key: None,
                 user_backfiller_github_access_token: None,
             },