collab: Sync model request overages to Stripe (#29583)

Marshall Bowers created

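This PR adds syncing of model request overages to Stripe. A background task runs every 60 seconds, loads the current subscription usage meters, matches them to users with an active Zed Pro subscription, subscribes each subscription to the model's request price, and reports the meter's request count to Stripe as a meter event. Only `claude-3-5-sonnet` and `claude-3-7-sonnet` are supported; meters for other models, or for users without an active Zed Pro subscription, are logged as errors and skipped.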

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                              | 111 ++++
crates/collab/src/db/queries/billing_subscriptions.rs         |  32 +
crates/collab/src/llm/db/queries.rs                           |   1 
crates/collab/src/llm/db/queries/subscription_usage_meters.rs |  37 +
crates/collab/src/llm/db/queries/subscription_usages.rs       |   2 
crates/collab/src/llm/db/tables.rs                            |   1 
crates/collab/src/llm/db/tables/subscription_usage_meter.rs   |  43 +
crates/collab/src/main.rs                                     |   5 
crates/collab/src/stripe_billing.rs                           | 102 ++++
9 files changed, 318 insertions(+), 16 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -393,7 +393,9 @@ async fn create_billing_subscription(
                 zed_llm_client::LanguageModelProvider::Anthropic,
                 "claude-3-7-sonnet",
             )?;
-            let stripe_model = stripe_billing.register_model(default_model).await?;
+            let stripe_model = stripe_billing
+                .register_model_for_token_based_usage(default_model)
+                .await?;
             stripe_billing
                 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
                 .await?
@@ -1303,7 +1305,9 @@ async fn sync_token_usage_with_stripe(
             .parse()
             .context("failed to parse stripe customer id from db")?;
 
-        let stripe_model = stripe_billing.register_model(&model).await?;
+        let stripe_model = stripe_billing
+            .register_model_for_token_based_usage(&model)
+            .await?;
         stripe_billing
             .subscribe_to_model(&stripe_subscription_id, &stripe_model)
             .await?;
@@ -1315,3 +1319,106 @@ async fn sync_token_usage_with_stripe(
 
     Ok(())
 }
+
+const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
+
+pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
+    let Some(stripe_billing) = app.stripe_billing.clone() else {
+        log::warn!("failed to retrieve Stripe billing object");
+        return;
+    };
+    let Some(llm_db) = app.llm_db.clone() else {
+        log::warn!("failed to retrieve LLM database");
+        return;
+    };
+
+    let executor = app.executor.clone();
+    executor.spawn_detached({
+        let executor = executor.clone();
+        async move {
+            loop {
+                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
+                    .await
+                    .context("failed to sync LLM request usage to Stripe")
+                    .trace_err();
+                executor
+                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
+                    .await;
+            }
+        }
+    });
+}
+
+async fn sync_model_request_usage_with_stripe(
+    app: &Arc<AppState>,
+    llm_db: &Arc<LlmDatabase>,
+    stripe_billing: &Arc<StripeBilling>,
+) -> anyhow::Result<()> {
+    let usage_meters = llm_db
+        .get_current_subscription_usage_meters(Utc::now())
+        .await?;
+    let user_ids = usage_meters
+        .iter()
+        .map(|(_, usage)| usage.user_id)
+        .collect::<HashSet<UserId>>();
+    let billing_subscriptions = app
+        .db
+        .get_active_zed_pro_billing_subscriptions(user_ids)
+        .await?;
+
+    let claude_3_5_sonnet = stripe_billing
+        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
+        .await?;
+    let claude_3_7_sonnet = stripe_billing
+        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
+        .await?;
+
+    for (usage_meter, usage) in usage_meters {
+        maybe!(async {
+            let Some((billing_customer, billing_subscription)) =
+                billing_subscriptions.get(&usage.user_id)
+            else {
+                bail!(
+                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
+                    usage.user_id
+                );
+            };
+
+            let stripe_customer_id = billing_customer
+                .stripe_customer_id
+                .parse::<stripe::CustomerId>()
+                .context("failed to parse Stripe customer ID from database")?;
+            let stripe_subscription_id = billing_subscription
+                .stripe_subscription_id
+                .parse::<stripe::SubscriptionId>()
+                .context("failed to parse Stripe subscription ID from database")?;
+
+            let model = llm_db.model_by_id(usage_meter.model_id)?;
+
+            let (price_id, meter_event_name) = match model.name.as_str() {
+                "claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
+                "claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
+                model_name => {
+                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
+                }
+            };
+
+            stripe_billing
+                .subscribe_to_price(&stripe_subscription_id, price_id)
+                .await?;
+            stripe_billing
+                .bill_model_request_usage(
+                    &stripe_customer_id,
+                    meter_event_name,
+                    usage_meter.requests,
+                )
+                .await?;
+
+            Ok(())
+        })
+        .await
+        .log_err();
+    }
+
+    Ok(())
+}

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -191,6 +191,38 @@ impl Database {
         .await
     }
 
+    pub async fn get_active_zed_pro_billing_subscriptions(
+        &self,
+        user_ids: HashSet<UserId>,
+    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
+        self.transaction(|tx| {
+            let user_ids = user_ids.clone();
+            async move {
+                let mut rows = billing_subscription::Entity::find()
+                    .inner_join(billing_customer::Entity)
+                    .select_also(billing_customer::Entity)
+                    .filter(billing_customer::Column::UserId.is_in(user_ids))
+                    .filter(
+                        billing_subscription::Column::StripeSubscriptionStatus
+                            .eq(StripeSubscriptionStatus::Active),
+                    )
+                    .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
+                    .order_by_asc(billing_subscription::Column::Id)
+                    .stream(&*tx)
+                    .await?;
+
+                let mut subscriptions = HashMap::default();
+                while let Some(row) = rows.next().await {
+                    if let (subscription, Some(customer)) = row? {
+                        subscriptions.insert(customer.user_id, (customer, subscription));
+                    }
+                }
+                Ok(subscriptions)
+            }
+        })
+        .await
+    }
+
     /// Returns whether the user has an active billing subscription.
     pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
         Ok(self.count_active_billing_subscriptions(user_id).await? > 0)

crates/collab/src/llm/db/queries.rs 🔗

@@ -2,5 +2,6 @@ use super::*;
 
 pub mod billing_events;
 pub mod providers;
+pub mod subscription_usage_meters;
 pub mod subscription_usages;
 pub mod usages;

crates/collab/src/llm/db/queries/subscription_usage_meters.rs 🔗

@@ -0,0 +1,37 @@
+use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
+
+use super::*;
+
+impl LlmDatabase {
+    /// Returns all current subscription usage meters as of the given timestamp.
+    pub async fn get_current_subscription_usage_meters(
+        &self,
+        now: DateTimeUtc,
+    ) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
+        let now = convert_chrono_to_time(now)?;
+
+        self.transaction(|tx| async move {
+            let result = subscription_usage_meter::Entity::find()
+                .inner_join(subscription_usage::Entity)
+                .filter(
+                    subscription_usage::Column::PeriodStartAt
+                        .lte(now)
+                        .and(subscription_usage::Column::PeriodEndAt.gte(now)),
+                )
+                .select_also(subscription_usage::Entity)
+                .all(&*tx)
+                .await?;
+
+            let result = result
+                .into_iter()
+                .filter_map(|(meter, usage)| {
+                    let usage = usage?;
+                    Some((meter, usage))
+                })
+                .collect();
+
+            Ok(result)
+        })
+        .await
+    }
+}

crates/collab/src/llm/db/queries/subscription_usages.rs 🔗

@@ -6,7 +6,7 @@ use crate::db::{UserId, billing_subscription};
 
 use super::*;
 
-fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
+pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
     use chrono::{Datelike as _, Timelike as _};
 
     let date = time::Date::from_calendar_date(

crates/collab/src/llm/db/tables.rs 🔗

@@ -3,5 +3,6 @@ pub mod model;
 pub mod monthly_usage;
 pub mod provider;
 pub mod subscription_usage;
+pub mod subscription_usage_meter;
 pub mod usage;
 pub mod usage_measure;

crates/collab/src/llm/db/tables/subscription_usage_meter.rs 🔗

@@ -0,0 +1,43 @@
+use sea_orm::entity::prelude::*;
+
+use crate::llm::db::ModelId;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "subscription_usage_meters")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: i32,
+    pub subscription_usage_id: i32,
+    pub model_id: ModelId,
+    pub requests: i32,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::subscription_usage::Entity",
+        from = "Column::SubscriptionUsageId",
+        to = "super::subscription_usage::Column::Id"
+    )]
+    SubscriptionUsage,
+    #[sea_orm(
+        belongs_to = "super::model::Entity",
+        from = "Column::ModelId",
+        to = "super::model::Column::Id"
+    )]
+    Model,
+}
+
+impl Related<super::subscription_usage::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::SubscriptionUsage.def()
+    }
+}
+
+impl Related<super::model::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Model.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/main.rs 🔗

@@ -8,7 +8,9 @@ use axum::{
 };
 
 use collab::api::CloudflareIpCountryHeader;
-use collab::api::billing::sync_llm_token_usage_with_stripe_periodically;
+use collab::api::billing::{
+    sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically,
+};
 use collab::llm::db::LlmDatabase;
 use collab::migrations::run_database_migrations;
 use collab::user_backfiller::spawn_user_backfiller;
@@ -152,6 +154,7 @@ async fn main() -> Result<()> {
 
                     if let Some(mut llm_db) = llm_db {
                         llm_db.initialize().await?;
+                        sync_llm_request_usage_with_stripe_periodically(state.clone());
                         sync_llm_token_usage_with_stripe_periodically(state.clone());
                     }
 

crates/collab/src/stripe_billing.rs 🔗

@@ -1,12 +1,13 @@
 use std::sync::Arc;
 
 use crate::{Cents, Result, llm};
-use anyhow::Context as _;
+use anyhow::{Context as _, anyhow};
 use chrono::{Datelike, Utc};
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
 use stripe::PriceId;
 use tokio::sync::RwLock;
+use uuid::Uuid;
 
 pub struct StripeBilling {
     state: RwLock<StripeBillingState>,
@@ -17,9 +18,10 @@ pub struct StripeBilling {
 struct StripeBillingState {
     meters_by_event_name: HashMap<String, StripeMeter>,
     price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+    prices_by_lookup_key: HashMap<String, stripe::Price>,
 }
 
-pub struct StripeModel {
+pub struct StripeModelTokenPrices {
     input_tokens_price: StripeBillingPrice,
     input_cache_creation_tokens_price: StripeBillingPrice,
     input_cache_read_tokens_price: StripeBillingPrice,
@@ -62,6 +64,10 @@ impl StripeBilling {
         }
 
         for price in prices.data {
+            if let Some(lookup_key) = price.lookup_key.clone() {
+                state.prices_by_lookup_key.insert(lookup_key, price.clone());
+            }
+
             if let Some(recurring) = price.recurring {
                 if let Some(meter) = recurring.meter {
                     state.price_ids_by_meter_id.insert(meter, price.id);
@@ -74,36 +80,49 @@ impl StripeBilling {
         Ok(())
     }
 
-    pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
+    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
+        self.state
+            .read()
+            .await
+            .prices_by_lookup_key
+            .get(lookup_key)
+            .cloned()
+            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
+    }
+
+    pub async fn register_model_for_token_based_usage(
+        &self,
+        model: &llm::db::model::Model,
+    ) -> Result<StripeModelTokenPrices> {
         let input_tokens_price = self
-            .get_or_insert_price(
+            .get_or_insert_token_price(
                 &format!("model_{}/input_tokens", model.id),
                 &format!("{} (Input Tokens)", model.name),
                 Cents::new(model.price_per_million_input_tokens as u32),
             )
             .await?;
         let input_cache_creation_tokens_price = self
-            .get_or_insert_price(
+            .get_or_insert_token_price(
                 &format!("model_{}/input_cache_creation_tokens", model.id),
                 &format!("{} (Input Cache Creation Tokens)", model.name),
                 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
             )
             .await?;
         let input_cache_read_tokens_price = self
-            .get_or_insert_price(
+            .get_or_insert_token_price(
                 &format!("model_{}/input_cache_read_tokens", model.id),
                 &format!("{} (Input Cache Read Tokens)", model.name),
                 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
             )
             .await?;
         let output_tokens_price = self
-            .get_or_insert_price(
+            .get_or_insert_token_price(
                 &format!("model_{}/output_tokens", model.id),
                 &format!("{} (Output Tokens)", model.name),
                 Cents::new(model.price_per_million_output_tokens as u32),
             )
             .await?;
-        Ok(StripeModel {
+        Ok(StripeModelTokenPrices {
             input_tokens_price,
             input_cache_creation_tokens_price,
             input_cache_read_tokens_price,
@@ -111,7 +130,7 @@ impl StripeBilling {
         })
     }
 
-    async fn get_or_insert_price(
+    async fn get_or_insert_token_price(
         &self,
         meter_event_name: &str,
         price_description: &str,
@@ -207,10 +226,43 @@ impl StripeBilling {
         })
     }
 
+    pub async fn subscribe_to_price(
+        &self,
+        subscription_id: &stripe::SubscriptionId,
+        price_id: &stripe::PriceId,
+    ) -> Result<()> {
+        let subscription =
+            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+
+        if subscription_contains_price(&subscription, price_id) {
+            return Ok(());
+        }
+
+        stripe::Subscription::update(
+            &self.client,
+            subscription_id,
+            stripe::UpdateSubscription {
+                items: Some(vec![stripe::UpdateSubscriptionItems {
+                    price: Some(price_id.to_string()),
+                    ..Default::default()
+                }]),
+                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
+                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
+                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+                    },
+                }),
+                ..Default::default()
+            },
+        )
+        .await?;
+
+        Ok(())
+    }
+
     pub async fn subscribe_to_model(
         &self,
         subscription_id: &stripe::SubscriptionId,
-        model: &StripeModel,
+        model: &StripeModelTokenPrices,
     ) -> Result<()> {
         let subscription =
             stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
@@ -271,7 +323,7 @@ impl StripeBilling {
     pub async fn bill_model_token_usage(
         &self,
         customer_id: &stripe::CustomerId,
-        model: &StripeModel,
+        model: &StripeModelTokenPrices,
         event: &llm::db::billing_event::Model,
     ) -> Result<()> {
         let timestamp = Utc::now().timestamp();
@@ -343,11 +395,37 @@ impl StripeBilling {
         Ok(())
     }
 
+    pub async fn bill_model_request_usage(
+        &self,
+        customer_id: &stripe::CustomerId,
+        event_name: &str,
+        requests: i32,
+    ) -> Result<()> {
+        let timestamp = Utc::now().timestamp();
+        let idempotency_key = Uuid::new_v4();
+
+        StripeMeterEvent::create(
+            &self.client,
+            StripeCreateMeterEventParams {
+                identifier: &format!("model_requests/{}", idempotency_key),
+                event_name,
+                payload: StripeCreateMeterEventPayload {
+                    value: requests as u64,
+                    stripe_customer_id: customer_id,
+                },
+                timestamp: Some(timestamp),
+            },
+        )
+        .await?;
+
+        Ok(())
+    }
+
     pub async fn checkout(
         &self,
         customer_id: stripe::CustomerId,
         github_login: &str,
-        model: &StripeModel,
+        model: &StripeModelTokenPrices,
         success_url: &str,
     ) -> Result<String> {
         let first_of_next_month = Utc::now()