collab: Remove code for syncing token-based billing events (#30130)

Marshall Bowers created

This PR removes the code related to syncing token-based billing events
to Stripe.

We don't need this anymore with the new billing.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                   |  96 ----
crates/collab/src/llm/db/queries.rs                |   1 
crates/collab/src/llm/db/queries/billing_events.rs |  31 -
crates/collab/src/llm/db/tables.rs                 |   1 
crates/collab/src/llm/db/tables/billing_event.rs   |  37 -
crates/collab/src/llm/db/tables/model.rs           |   8 
crates/collab/src/main.rs                          |   5 
crates/collab/src/stripe_billing.rs                | 350 ---------------
8 files changed, 8 insertions(+), 521 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -301,13 +301,6 @@ async fn create_billing_subscription(
             "not supported".into(),
         ))?
     };
-    let Some(llm_db) = app.llm_db.clone() else {
-        log::error!("failed to retrieve LLM database");
-        Err(Error::http(
-            StatusCode::NOT_IMPLEMENTED,
-            "not supported".into(),
-        ))?
-    };
 
     if app.db.has_active_billing_subscription(user.id).await? {
         return Err(Error::http(
@@ -399,16 +392,10 @@ async fn create_billing_subscription(
                 .await?
         }
         None => {
-            let default_model = llm_db.model(
-                zed_llm_client::LanguageModelProvider::Anthropic,
-                "claude-3-7-sonnet",
-            )?;
-            let stripe_model = stripe_billing
-                .register_model_for_token_based_usage(default_model)
-                .await?;
-            stripe_billing
-                .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
-                .await?
+            return Err(Error::http(
+                StatusCode::BAD_REQUEST,
+                "No product selected".into(),
+            ));
         }
     };
 
@@ -1381,81 +1368,6 @@ async fn find_or_create_billing_customer(
     Ok(Some(billing_customer))
 }
 
-const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
-
-pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
-    let Some(stripe_billing) = app.stripe_billing.clone() else {
-        log::warn!("failed to retrieve Stripe billing object");
-        return;
-    };
-    let Some(llm_db) = app.llm_db.clone() else {
-        log::warn!("failed to retrieve LLM database");
-        return;
-    };
-
-    let executor = app.executor.clone();
-    executor.spawn_detached({
-        let executor = executor.clone();
-        async move {
-            loop {
-                sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
-                    .await
-                    .context("failed to sync LLM usage to Stripe")
-                    .trace_err();
-                executor
-                    .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
-                    .await;
-            }
-        }
-    });
-}
-
-async fn sync_token_usage_with_stripe(
-    app: &Arc<AppState>,
-    llm_db: &Arc<LlmDatabase>,
-    stripe_billing: &Arc<StripeBilling>,
-) -> anyhow::Result<()> {
-    let events = llm_db.get_billing_events().await?;
-    let user_ids = events
-        .iter()
-        .map(|(event, _)| event.user_id)
-        .collect::<HashSet<UserId>>();
-    let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
-
-    for (event, model) in events {
-        let Some((stripe_db_customer, stripe_db_subscription)) =
-            stripe_subscriptions.get(&event.user_id)
-        else {
-            tracing::warn!(
-                user_id = event.user_id.0,
-                "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
-            );
-            continue;
-        };
-        let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
-            .stripe_subscription_id
-            .parse()
-            .context("failed to parse stripe subscription id from db")?;
-        let stripe_customer_id: stripe::CustomerId = stripe_db_customer
-            .stripe_customer_id
-            .parse()
-            .context("failed to parse stripe customer id from db")?;
-
-        let stripe_model = stripe_billing
-            .register_model_for_token_based_usage(&model)
-            .await?;
-        stripe_billing
-            .subscribe_to_model(&stripe_subscription_id, &stripe_model)
-            .await?;
-        stripe_billing
-            .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
-            .await?;
-        llm_db.consume_billing_event(event.id).await?;
-    }
-
-    Ok(())
-}
-
 const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
 
 pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {

crates/collab/src/llm/db/queries/billing_events.rs 🔗

@@ -1,31 +0,0 @@
-use super::*;
-use crate::Result;
-use anyhow::Context as _;
-
-impl LlmDatabase {
-    pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
-        self.transaction(|tx| async move {
-            let events_with_models = billing_event::Entity::find()
-                .find_also_related(model::Entity)
-                .all(&*tx)
-                .await?;
-            events_with_models
-                .into_iter()
-                .map(|(event, model)| {
-                    let model =
-                        model.context("could not find model associated with billing event")?;
-                    Ok((event, model))
-                })
-                .collect()
-        })
-        .await
-    }
-
-    pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
-        self.transaction(|tx| async move {
-            billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
-            Ok(())
-        })
-        .await
-    }
-}

crates/collab/src/llm/db/tables/billing_event.rs 🔗

@@ -1,37 +0,0 @@
-use crate::{
-    db::UserId,
-    llm::db::{BillingEventId, ModelId},
-};
-use sea_orm::entity::prelude::*;
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "billing_events")]
-pub struct Model {
-    #[sea_orm(primary_key)]
-    pub id: BillingEventId,
-    pub idempotency_key: Uuid,
-    pub user_id: UserId,
-    pub model_id: ModelId,
-    pub input_tokens: i64,
-    pub input_cache_creation_tokens: i64,
-    pub input_cache_read_tokens: i64,
-    pub output_tokens: i64,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
-    #[sea_orm(
-        belongs_to = "super::model::Entity",
-        from = "Column::ModelId",
-        to = "super::model::Column::Id"
-    )]
-    Model,
-}
-
-impl Related<super::model::Entity> for Entity {
-    fn to() -> RelationDef {
-        Relation::Model.def()
-    }
-}
-
-impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tables/model.rs 🔗

@@ -31,8 +31,6 @@ pub enum Relation {
     Provider,
     #[sea_orm(has_many = "super::usage::Entity")]
     Usages,
-    #[sea_orm(has_many = "super::billing_event::Entity")]
-    BillingEvents,
 }
 
 impl Related<super::provider::Entity> for Entity {
@@ -47,10 +45,4 @@ impl Related<super::usage::Entity> for Entity {
     }
 }
 
-impl Related<super::billing_event::Entity> for Entity {
-    fn to() -> RelationDef {
-        Relation::BillingEvents.def()
-    }
-}
-
 impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/main.rs 🔗

@@ -8,9 +8,7 @@ use axum::{
 };
 
 use collab::api::CloudflareIpCountryHeader;
-use collab::api::billing::{
-    sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically,
-};
+use collab::api::billing::sync_llm_request_usage_with_stripe_periodically;
 use collab::llm::db::LlmDatabase;
 use collab::migrations::run_database_migrations;
 use collab::user_backfiller::spawn_user_backfiller;
@@ -155,7 +153,6 @@ async fn main() -> Result<()> {
                     if let Some(mut llm_db) = llm_db {
                         llm_db.initialize().await?;
                         sync_llm_request_usage_with_stripe_periodically(state.clone());
-                        sync_llm_token_usage_with_stripe_periodically(state.clone());
                     }
 
                     app = app

crates/collab/src/stripe_billing.rs 🔗

@@ -1,9 +1,9 @@
 use std::sync::Arc;
 
-use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
-use crate::{Cents, Result};
+use crate::Result;
+use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 use anyhow::{Context as _, anyhow};
-use chrono::{Datelike, Utc};
+use chrono::Utc;
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
 use stripe::PriceId;
@@ -22,18 +22,6 @@ struct StripeBillingState {
     prices_by_lookup_key: HashMap<String, stripe::Price>,
 }
 
-pub struct StripeModelTokenPrices {
-    input_tokens_price: StripeBillingPrice,
-    input_cache_creation_tokens_price: StripeBillingPrice,
-    input_cache_read_tokens_price: StripeBillingPrice,
-    output_tokens_price: StripeBillingPrice,
-}
-
-struct StripeBillingPrice {
-    id: stripe::PriceId,
-    meter_event_name: String,
-}
-
 impl StripeBilling {
     pub fn new(client: Arc<stripe::Client>) -> Self {
         Self {
@@ -109,142 +97,6 @@ impl StripeBilling {
             .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
     }
 
-    pub async fn register_model_for_token_based_usage(
-        &self,
-        model: &llm::db::model::Model,
-    ) -> Result<StripeModelTokenPrices> {
-        let input_tokens_price = self
-            .get_or_insert_token_price(
-                &format!("model_{}/input_tokens", model.id),
-                &format!("{} (Input Tokens)", model.name),
-                Cents::new(model.price_per_million_input_tokens as u32),
-            )
-            .await?;
-        let input_cache_creation_tokens_price = self
-            .get_or_insert_token_price(
-                &format!("model_{}/input_cache_creation_tokens", model.id),
-                &format!("{} (Input Cache Creation Tokens)", model.name),
-                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
-            )
-            .await?;
-        let input_cache_read_tokens_price = self
-            .get_or_insert_token_price(
-                &format!("model_{}/input_cache_read_tokens", model.id),
-                &format!("{} (Input Cache Read Tokens)", model.name),
-                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
-            )
-            .await?;
-        let output_tokens_price = self
-            .get_or_insert_token_price(
-                &format!("model_{}/output_tokens", model.id),
-                &format!("{} (Output Tokens)", model.name),
-                Cents::new(model.price_per_million_output_tokens as u32),
-            )
-            .await?;
-        Ok(StripeModelTokenPrices {
-            input_tokens_price,
-            input_cache_creation_tokens_price,
-            input_cache_read_tokens_price,
-            output_tokens_price,
-        })
-    }
-
-    async fn get_or_insert_token_price(
-        &self,
-        meter_event_name: &str,
-        price_description: &str,
-        price_per_million_tokens: Cents,
-    ) -> Result<StripeBillingPrice> {
-        // Fast code path when the meter and the price already exist.
-        {
-            let state = self.state.read().await;
-            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
-                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
-                    return Ok(StripeBillingPrice {
-                        id: price_id.clone(),
-                        meter_event_name: meter_event_name.to_string(),
-                    });
-                }
-            }
-        }
-
-        let mut state = self.state.write().await;
-        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
-            meter.clone()
-        } else {
-            let meter = StripeMeter::create(
-                &self.client,
-                StripeCreateMeterParams {
-                    default_aggregation: DefaultAggregation { formula: "sum" },
-                    display_name: price_description.to_string(),
-                    event_name: meter_event_name,
-                },
-            )
-            .await?;
-            state
-                .meters_by_event_name
-                .insert(meter_event_name.to_string(), meter.clone());
-            meter
-        };
-
-        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
-            price_id.clone()
-        } else {
-            let price = stripe::Price::create(
-                &self.client,
-                stripe::CreatePrice {
-                    active: Some(true),
-                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
-                    currency: stripe::Currency::USD,
-                    currency_options: None,
-                    custom_unit_amount: None,
-                    expand: &[],
-                    lookup_key: None,
-                    metadata: None,
-                    nickname: None,
-                    product: None,
-                    product_data: Some(stripe::CreatePriceProductData {
-                        id: None,
-                        active: Some(true),
-                        metadata: None,
-                        name: price_description.to_string(),
-                        statement_descriptor: None,
-                        tax_code: None,
-                        unit_label: None,
-                    }),
-                    recurring: Some(stripe::CreatePriceRecurring {
-                        aggregate_usage: None,
-                        interval: stripe::CreatePriceRecurringInterval::Month,
-                        interval_count: None,
-                        trial_period_days: None,
-                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
-                        meter: Some(meter.id.clone()),
-                    }),
-                    tax_behavior: None,
-                    tiers: None,
-                    tiers_mode: None,
-                    transfer_lookup_key: None,
-                    transform_quantity: None,
-                    unit_amount: None,
-                    unit_amount_decimal: Some(&format!(
-                        "{:.12}",
-                        price_per_million_tokens.0 as f64 / 1_000_000f64
-                    )),
-                },
-            )
-            .await?;
-            state
-                .price_ids_by_meter_id
-                .insert(meter.id, price.id.clone());
-            price.id
-        };
-
-        Ok(StripeBillingPrice {
-            id: price_id,
-            meter_event_name: meter_event_name.to_string(),
-        })
-    }
-
     pub async fn subscribe_to_price(
         &self,
         subscription_id: &stripe::SubscriptionId,
@@ -283,142 +135,6 @@ impl StripeBilling {
         Ok(())
     }
 
-    pub async fn subscribe_to_model(
-        &self,
-        subscription_id: &stripe::SubscriptionId,
-        model: &StripeModelTokenPrices,
-    ) -> Result<()> {
-        let subscription =
-            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
-
-        let mut items = Vec::new();
-
-        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
-            items.push(stripe::UpdateSubscriptionItems {
-                price: Some(model.input_tokens_price.id.to_string()),
-                ..Default::default()
-            });
-        }
-
-        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
-        {
-            items.push(stripe::UpdateSubscriptionItems {
-                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
-                ..Default::default()
-            });
-        }
-
-        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
-            items.push(stripe::UpdateSubscriptionItems {
-                price: Some(model.input_cache_read_tokens_price.id.to_string()),
-                ..Default::default()
-            });
-        }
-
-        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
-            items.push(stripe::UpdateSubscriptionItems {
-                price: Some(model.output_tokens_price.id.to_string()),
-                ..Default::default()
-            });
-        }
-
-        if !items.is_empty() {
-            items.extend(subscription.items.data.iter().map(|item| {
-                stripe::UpdateSubscriptionItems {
-                    id: Some(item.id.to_string()),
-                    ..Default::default()
-                }
-            }));
-
-            stripe::Subscription::update(
-                &self.client,
-                subscription_id,
-                stripe::UpdateSubscription {
-                    items: Some(items),
-                    ..Default::default()
-                },
-            )
-            .await?;
-        }
-
-        Ok(())
-    }
-
-    pub async fn bill_model_token_usage(
-        &self,
-        customer_id: &stripe::CustomerId,
-        model: &StripeModelTokenPrices,
-        event: &llm::db::billing_event::Model,
-    ) -> Result<()> {
-        let timestamp = Utc::now().timestamp();
-
-        if event.input_tokens > 0 {
-            StripeMeterEvent::create(
-                &self.client,
-                StripeCreateMeterEventParams {
-                    identifier: &format!("input_tokens/{}", event.idempotency_key),
-                    event_name: &model.input_tokens_price.meter_event_name,
-                    payload: StripeCreateMeterEventPayload {
-                        value: event.input_tokens as u64,
-                        stripe_customer_id: customer_id,
-                    },
-                    timestamp: Some(timestamp),
-                },
-            )
-            .await?;
-        }
-
-        if event.input_cache_creation_tokens > 0 {
-            StripeMeterEvent::create(
-                &self.client,
-                StripeCreateMeterEventParams {
-                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
-                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
-                    payload: StripeCreateMeterEventPayload {
-                        value: event.input_cache_creation_tokens as u64,
-                        stripe_customer_id: customer_id,
-                    },
-                    timestamp: Some(timestamp),
-                },
-            )
-            .await?;
-        }
-
-        if event.input_cache_read_tokens > 0 {
-            StripeMeterEvent::create(
-                &self.client,
-                StripeCreateMeterEventParams {
-                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
-                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
-                    payload: StripeCreateMeterEventPayload {
-                        value: event.input_cache_read_tokens as u64,
-                        stripe_customer_id: customer_id,
-                    },
-                    timestamp: Some(timestamp),
-                },
-            )
-            .await?;
-        }
-
-        if event.output_tokens > 0 {
-            StripeMeterEvent::create(
-                &self.client,
-                StripeCreateMeterEventParams {
-                    identifier: &format!("output_tokens/{}", event.idempotency_key),
-                    event_name: &model.output_tokens_price.meter_event_name,
-                    payload: StripeCreateMeterEventPayload {
-                        value: event.output_tokens as u64,
-                        stripe_customer_id: customer_id,
-                    },
-                    timestamp: Some(timestamp),
-                },
-            )
-            .await?;
-        }
-
-        Ok(())
-    }
-
     pub async fn bill_model_request_usage(
         &self,
         customer_id: &stripe::CustomerId,
@@ -445,47 +161,6 @@ impl StripeBilling {
         Ok(())
     }
 
-    pub async fn checkout(
-        &self,
-        customer_id: stripe::CustomerId,
-        github_login: &str,
-        model: &StripeModelTokenPrices,
-        success_url: &str,
-    ) -> Result<String> {
-        let first_of_next_month = Utc::now()
-            .checked_add_months(chrono::Months::new(1))
-            .unwrap()
-            .with_day(1)
-            .unwrap();
-
-        let mut params = stripe::CreateCheckoutSession::new();
-        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
-        params.customer = Some(customer_id);
-        params.client_reference_id = Some(github_login);
-        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
-            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
-            ..Default::default()
-        });
-        params.line_items = Some(
-            [
-                &model.input_tokens_price.id,
-                &model.input_cache_creation_tokens_price.id,
-                &model.input_cache_read_tokens_price.id,
-                &model.output_tokens_price.id,
-            ]
-            .into_iter()
-            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
-                price: Some(price_id.to_string()),
-                ..Default::default()
-            })
-            .collect(),
-        );
-        params.success_url = Some(success_url);
-
-        let session = stripe::CheckoutSession::create(&self.client, params).await?;
-        Ok(session.url.context("no checkout session URL")?)
-    }
-
     pub async fn checkout_with_zed_pro(
         &self,
         customer_id: stripe::CustomerId,
@@ -587,18 +262,6 @@ impl StripeBilling {
     }
 }
 
-#[derive(Serialize)]
-struct DefaultAggregation {
-    formula: &'static str,
-}
-
-#[derive(Serialize)]
-struct StripeCreateMeterParams<'a> {
-    default_aggregation: DefaultAggregation,
-    display_name: String,
-    event_name: &'a str,
-}
-
 #[derive(Clone, Deserialize)]
 struct StripeMeter {
     id: String,
@@ -606,13 +269,6 @@ struct StripeMeter {
 }
 
 impl StripeMeter {
-    pub fn create(
-        client: &stripe::Client,
-        params: StripeCreateMeterParams,
-    ) -> stripe::Response<Self> {
-        client.post_form("/billing/meters", params)
-    }
-
     pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
         #[derive(Serialize)]
         struct Params {