collab: Add usage-based billing for LLM interactions (#19081)

Marshall Bowers , Antonio Scandurra , Antonio , Richard , and Richard Feldman created

This PR adds usage-based billing for LLM interactions in the Assistant.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>

Change summary

Cargo.lock                                                            |   5 
Cargo.toml                                                            |   3 
crates/collab/migrations_llm/20241010151249_create_billing_events.sql |  12 
crates/collab/src/api/billing.rs                                      | 203 
crates/collab/src/db/queries/billing_subscriptions.rs                 |  38 
crates/collab/src/lib.rs                                              |  29 
crates/collab/src/llm.rs                                              |  65 
crates/collab/src/llm/db.rs                                           |   2 
crates/collab/src/llm/db/ids.rs                                       |   1 
crates/collab/src/llm/db/queries.rs                                   |   1 
crates/collab/src/llm/db/queries/billing_events.rs                    |  31 
crates/collab/src/llm/db/queries/usages.rs                            | 133 
crates/collab/src/llm/db/tables.rs                                    |   1 
crates/collab/src/llm/db/tables/billing_event.rs                      |  37 
crates/collab/src/llm/db/tables/model.rs                              |   8 
crates/collab/src/llm/db/tests/usage_tests.rs                         | 192 
crates/collab/src/main.rs                                             |   2 
crates/collab/src/rpc.rs                                              |   9 
crates/collab/src/stripe_billing.rs                                   | 427 
crates/collab/src/tests/test_server.rs                                |   3 
20 files changed, 920 insertions(+), 282 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -839,9 +839,8 @@ dependencies = [
 
 [[package]]
 name = "async-stripe"
-version = "0.39.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "58d670cf4d47a1b8ffef54286a5625382e360a34ee76902fd93ad8c7032a0c30"
+version = "0.40.0"
+source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
 dependencies = [
  "chrono",
  "futures-util",

Cargo.toml 🔗

@@ -480,7 +480,8 @@ which = "6.0.0"
 wit-component = "0.201"
 
 [workspace.dependencies.async-stripe]
-version = "0.39"
+git = "https://github.com/zed-industries/async-stripe"
+rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
 default-features = false
 features = [
     "runtime-tokio-hyper-rustls",

crates/collab/migrations_llm/20241010151249_create_billing_events.sql 🔗

@@ -0,0 +1,12 @@
+create table billing_events (
+    id serial primary key,
+    idempotency_key uuid not null default gen_random_uuid(),
+    user_id integer not null,
+    model_id integer not null references models (id) on delete cascade,
+    input_tokens bigint not null default 0,
+    input_cache_creation_tokens bigint not null default 0,
+    input_cache_read_tokens bigint not null default 0,
+    output_tokens bigint not null default 0
+);
+
+create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);

crates/collab/src/api/billing.rs 🔗

@@ -1,7 +1,3 @@
-use std::str::FromStr;
-use std::sync::Arc;
-use std::time::Duration;
-
 use anyhow::{anyhow, bail, Context};
 use axum::{
     extract::{self, Query},
@@ -9,28 +5,35 @@ use axum::{
     Extension, Json, Router,
 };
 use chrono::{DateTime, SecondsFormat, Utc};
+use collections::HashSet;
 use reqwest::StatusCode;
 use sea_orm::ActiveValue;
 use serde::{Deserialize, Serialize};
+use std::{str::FromStr, sync::Arc, time::Duration};
 use stripe::{
-    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
-    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
+    BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
+    CreateBillingPortalSessionFlowDataAfterCompletion,
     CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
-    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
-    CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
-    Subscription, SubscriptionId, SubscriptionStatus,
+    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
+    EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
 };
 use util::ResultExt;
 
-use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
-use crate::db::{
-    billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
-    CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
-    UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
-};
-use crate::llm::db::LlmDatabase;
-use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
+use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
 use crate::rpc::ResultExt as _;
+use crate::{
+    db::{
+        billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
+        CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
+        UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
+        UpdateBillingSubscriptionParams,
+    },
+    stripe_billing::StripeBilling,
+};
+use crate::{
+    db::{billing_subscription::StripeSubscriptionStatus, UserId},
+    llm::db::LlmDatabase,
+};
 use crate::{AppState, Error, Result};
 
 pub fn router() -> Router {
@@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody {
 
 async fn update_billing_preferences(
     Extension(app): Extension<Arc<AppState>>,
+    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
     extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 ) -> Result<Json<BillingPreferencesResponse>> {
     let user = app
@@ -119,6 +123,8 @@ async fn update_billing_preferences(
                 .await?
         };
 
+    rpc_server.refresh_llm_tokens_for_user(user.id).await;
+
     Ok(Json(BillingPreferencesResponse {
         max_monthly_llm_usage_spending_in_cents: billing_preferences
             .max_monthly_llm_usage_spending_in_cents,
@@ -197,12 +203,15 @@ async fn create_billing_subscription(
         .await?
         .ok_or_else(|| anyhow!("user not found"))?;
 
-    let Some((stripe_client, stripe_access_price_id)) = app
-        .stripe_client
-        .clone()
-        .zip(app.config.stripe_llm_access_price_id.clone())
-    else {
-        log::error!("failed to retrieve Stripe client or price ID");
+    let Some(stripe_client) = app.stripe_client.clone() else {
+        log::error!("failed to retrieve Stripe client");
+        Err(Error::http(
+            StatusCode::NOT_IMPLEMENTED,
+            "not supported".into(),
+        ))?
+    };
+    let Some(llm_db) = app.llm_db.clone() else {
+        log::error!("failed to retrieve LLM database");
         Err(Error::http(
             StatusCode::NOT_IMPLEMENTED,
             "not supported".into(),
@@ -226,26 +235,15 @@ async fn create_billing_subscription(
             customer.id
         };
 
-    let checkout_session = {
-        let mut params = CreateCheckoutSession::new();
-        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
-        params.customer = Some(customer_id);
-        params.client_reference_id = Some(user.github_login.as_str());
-        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
-            price: Some(stripe_access_price_id.to_string()),
-            quantity: Some(1),
-            ..Default::default()
-        }]);
-        let success_url = format!("{}/account", app.config.zed_dot_dev_url());
-        params.success_url = Some(&success_url);
-
-        CheckoutSession::create(&stripe_client, params).await?
-    };
-
+    let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
+    let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
+    let stripe_model = stripe_billing.register_model(default_model).await?;
+    let success_url = format!("{}/account", app.config.zed_dot_dev_url());
+    let checkout_session_url = stripe_billing
+        .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
+        .await?;
     Ok(Json(CreateBillingSubscriptionResponse {
-        checkout_session_url: checkout_session
-            .url
-            .ok_or_else(|| anyhow!("no checkout session URL"))?,
+        checkout_session_url,
     }))
 }
 
@@ -715,15 +713,15 @@ async fn find_or_create_billing_customer(
     Ok(Some(billing_customer))
 }
 
-const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
+const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
 
-pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
+pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
     let Some(stripe_client) = app.stripe_client.clone() else {
         log::warn!("failed to retrieve Stripe client");
         return;
     };
-    let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
-        log::warn!("failed to retrieve Stripe LLM usage price ID");
+    let Some(llm_db) = app.llm_db.clone() else {
+        log::warn!("failed to retrieve LLM database");
         return;
     };
 
@@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
         let executor = executor.clone();
         async move {
             loop {
-                sync_with_stripe(
-                    &app,
-                    &llm_db,
-                    &stripe_client,
-                    stripe_llm_usage_price_id.clone(),
-                )
-                .await
-                .trace_err();
-
+                sync_with_stripe(&app, &llm_db, &stripe_client)
+                    .await
+                    .trace_err();
                 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
             }
         }
@@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
 
 async fn sync_with_stripe(
     app: &Arc<AppState>,
-    llm_db: &LlmDatabase,
-    stripe_client: &stripe::Client,
-    stripe_llm_usage_price_id: Arc<str>,
+    llm_db: &Arc<LlmDatabase>,
+    stripe_client: &Arc<stripe::Client>,
 ) -> anyhow::Result<()> {
-    let subscriptions = app.db.get_active_billing_subscriptions().await?;
-
-    for (customer, subscription) in subscriptions {
-        update_stripe_subscription(
-            llm_db,
-            stripe_client,
-            &stripe_llm_usage_price_id,
-            customer,
-            subscription,
-        )
-        .await
-        .log_err();
-    }
-
-    Ok(())
-}
-
-async fn update_stripe_subscription(
-    llm_db: &LlmDatabase,
-    stripe_client: &stripe::Client,
-    stripe_llm_usage_price_id: &Arc<str>,
-    customer: billing_customer::Model,
-    subscription: billing_subscription::Model,
-) -> Result<(), anyhow::Error> {
-    let monthly_spending = llm_db
-        .get_user_spending_for_month(customer.user_id, Utc::now())
-        .await?;
-    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
-        .context("failed to parse subscription ID")?;
-
-    let monthly_spending_over_free_tier =
-        monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
-
-    let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
-    let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
-
-    let mut update_params = stripe::UpdateSubscription {
-        proration_behavior: Some(
-            stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
-        ),
-        ..Default::default()
-    };
-
-    if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
-        item.price.as_ref().map_or(false, |price| {
-            price.id == stripe_llm_usage_price_id.as_ref()
-        })
-    }) {
-        update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
-            id: Some(existing_item.id.to_string()),
-            quantity: Some(new_quantity as u64),
-            ..Default::default()
-        }]);
-    } else {
-        update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
-            price: Some(stripe_llm_usage_price_id.to_string()),
-            quantity: Some(new_quantity as u64),
-            ..Default::default()
-        }]);
+    let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
+
+    let events = llm_db.get_billing_events().await?;
+    let user_ids = events
+        .iter()
+        .map(|(event, _)| event.user_id)
+        .collect::<HashSet<UserId>>();
+    let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
+
+    for (event, model) in events {
+        let Some((stripe_db_customer, stripe_db_subscription)) =
+            stripe_subscriptions.get(&event.user_id)
+        else {
+            tracing::warn!(
+                user_id = event.user_id.0,
+                "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
+            );
+            continue;
+        };
+        let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
+            .stripe_subscription_id
+            .parse()
+            .context("failed to parse stripe subscription id from db")?;
+        let stripe_customer_id: stripe::CustomerId = stripe_db_customer
+            .stripe_customer_id
+            .parse()
+            .context("failed to parse stripe customer id from db")?;
+
+        let stripe_model = stripe_billing.register_model(&model).await?;
+        stripe_billing
+            .subscribe_to_model(&stripe_subscription_id, &stripe_model)
+            .await?;
+        stripe_billing
+            .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
+            .await?;
+        llm_db.consume_billing_event(event.id).await?;
     }
 
-    Subscription::update(stripe_client, &subscription_id, update_params).await?;
     Ok(())
 }

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -114,23 +114,31 @@ impl Database {
 
     pub async fn get_active_billing_subscriptions(
         &self,
-    ) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
-        self.transaction(|tx| async move {
-            let mut result = Vec::new();
-            let mut rows = billing_subscription::Entity::find()
-                .inner_join(billing_customer::Entity)
-                .select_also(billing_customer::Entity)
-                .order_by_asc(billing_subscription::Column::Id)
-                .stream(&*tx)
-                .await?;
-
-            while let Some(row) = rows.next().await {
-                if let (subscription, Some(customer)) = row? {
-                    result.push((customer, subscription));
+        user_ids: HashSet<UserId>,
+    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
+        self.transaction(|tx| {
+            let user_ids = user_ids.clone();
+            async move {
+                let mut rows = billing_subscription::Entity::find()
+                    .inner_join(billing_customer::Entity)
+                    .select_also(billing_customer::Entity)
+                    .filter(billing_customer::Column::UserId.is_in(user_ids))
+                    .filter(
+                        billing_subscription::Column::StripeSubscriptionStatus
+                            .eq(StripeSubscriptionStatus::Active),
+                    )
+                    .order_by_asc(billing_subscription::Column::Id)
+                    .stream(&*tx)
+                    .await?;
+
+                let mut subscriptions = HashMap::default();
+                while let Some(row) = rows.next().await {
+                    if let (subscription, Some(customer)) = row? {
+                        subscriptions.insert(customer.user_id, (customer, subscription));
+                    }
                 }
+                Ok(subscriptions)
             }
-
-            Ok(result)
         })
         .await
     }

crates/collab/src/lib.rs 🔗

@@ -10,6 +10,7 @@ pub mod migrations;
 mod rate_limiter;
 pub mod rpc;
 pub mod seed;
+pub mod stripe_billing;
 pub mod user_backfiller;
 
 #[cfg(test)]
@@ -24,6 +25,7 @@ use axum::{
 pub use cents::*;
 use db::{ChannelId, Database};
 use executor::Executor;
+use llm::db::LlmDatabase;
 pub use rate_limiter::*;
 use serde::Deserialize;
 use std::{path::PathBuf, sync::Arc};
@@ -176,8 +178,6 @@ pub struct Config {
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,
     pub stripe_api_key: Option<String>,
-    pub stripe_llm_access_price_id: Option<Arc<str>>,
-    pub stripe_llm_usage_price_id: Option<Arc<str>>,
     pub supermaven_admin_api_key: Option<Arc<str>>,
     pub user_backfiller_github_access_token: Option<Arc<str>>,
 }
@@ -197,7 +197,7 @@ impl Config {
     }
 
     pub fn is_llm_billing_enabled(&self) -> bool {
-        self.stripe_llm_usage_price_id.is_some()
+        self.stripe_api_key.is_some()
     }
 
     #[cfg(test)]
@@ -238,8 +238,6 @@ impl Config {
             migrations_path: None,
             seed_path: None,
             stripe_api_key: None,
-            stripe_llm_access_price_id: None,
-            stripe_llm_usage_price_id: None,
             supermaven_admin_api_key: None,
             user_backfiller_github_access_token: None,
         }
@@ -272,6 +270,7 @@ impl ServiceMode {
 
 pub struct AppState {
     pub db: Arc<Database>,
+    pub llm_db: Option<Arc<LlmDatabase>>,
     pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
     pub blob_store_client: Option<aws_sdk_s3::Client>,
     pub stripe_client: Option<Arc<stripe::Client>>,
@@ -288,6 +287,20 @@ impl AppState {
         let mut db = Database::new(db_options, Executor::Production).await?;
         db.initialize_notification_kinds().await?;
 
+        let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
+            .llm_database_url
+            .clone()
+            .zip(config.llm_database_max_connections)
+        {
+            let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
+            llm_db_options.max_connections(llm_database_max_connections);
+            let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
+            llm_db.initialize().await?;
+            Some(Arc::new(llm_db))
+        } else {
+            None
+        };
+
         let live_kit_client = if let Some(((server, key), secret)) = config
             .live_kit_server
             .as_ref()
@@ -306,9 +319,10 @@ impl AppState {
         let db = Arc::new(db);
         let this = Self {
             db: db.clone(),
+            llm_db,
             live_kit_client,
             blob_store_client: build_blob_store_client(&config).await.log_err(),
-            stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(),
+            stripe_client: build_stripe_client(&config).map(Arc::new).log_err(),
             rate_limiter: Arc::new(RateLimiter::new(db)),
             executor,
             clickhouse_client: config
@@ -321,12 +335,11 @@ impl AppState {
     }
 }
 
-async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
+fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
     let api_key = config
         .stripe_api_key
         .as_ref()
         .ok_or_else(|| anyhow!("missing stripe_api_key"))?;
-
     Ok(stripe::Client::new(api_key))
 }
 

crates/collab/src/llm.rs 🔗

@@ -20,13 +20,14 @@ use axum::{
 };
 use chrono::{DateTime, Duration, Utc};
 use collections::HashMap;
+use db::TokenUsage;
 use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
 use futures::{Stream, StreamExt as _};
 use isahc_http_client::IsahcHttpClient;
-use rpc::ListModelsResponse;
 use rpc::{
     proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
 };
+use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
 use std::{
     pin::Pin,
     sync::Arc,
@@ -418,10 +419,7 @@ async fn perform_completion(
         claims,
         provider: params.provider,
         model,
-        input_tokens: 0,
-        output_tokens: 0,
-        cache_creation_input_tokens: 0,
-        cache_read_input_tokens: 0,
+        tokens: TokenUsage::default(),
         inner_stream: stream,
     })))
 }
@@ -476,6 +474,19 @@ async fn check_usage_limit(
                     "Maximum spending limit reached for this month.".to_string(),
                 ));
             }
+
+            if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
+                return Err(Error::Http(
+                    StatusCode::FORBIDDEN,
+                    "Maximum spending limit reached for this month.".to_string(),
+                    [(
+                        HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
+                        HeaderValue::from_static("true"),
+                    )]
+                    .into_iter()
+                    .collect(),
+                ));
+            }
         }
     }
 
@@ -598,10 +609,7 @@ struct TokenCountingStream<S> {
     claims: LlmTokenClaims,
     provider: LanguageModelProvider,
     model: String,
-    input_tokens: usize,
-    output_tokens: usize,
-    cache_creation_input_tokens: usize,
-    cache_read_input_tokens: usize,
+    tokens: TokenUsage,
     inner_stream: S,
 }
 
@@ -615,10 +623,10 @@ where
         match Pin::new(&mut self.inner_stream).poll_next(cx) {
             Poll::Ready(Some(Ok(mut chunk))) => {
                 chunk.bytes.push(b'\n');
-                self.input_tokens += chunk.input_tokens;
-                self.output_tokens += chunk.output_tokens;
-                self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
-                self.cache_read_input_tokens += chunk.cache_read_input_tokens;
+                self.tokens.input += chunk.input_tokens;
+                self.tokens.output += chunk.output_tokens;
+                self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
+                self.tokens.input_cache_read += chunk.cache_read_input_tokens;
                 Poll::Ready(Some(Ok(chunk.bytes)))
             }
             Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
@@ -634,10 +642,7 @@ impl<S> Drop for TokenCountingStream<S> {
         let claims = self.claims.clone();
         let provider = self.provider;
         let model = std::mem::take(&mut self.model);
-        let input_token_count = self.input_tokens;
-        let output_token_count = self.output_tokens;
-        let cache_creation_input_token_count = self.cache_creation_input_tokens;
-        let cache_read_input_token_count = self.cache_read_input_tokens;
+        let tokens = self.tokens;
         self.state.executor.spawn_detached(async move {
             let usage = state
                 .db
@@ -646,10 +651,9 @@ impl<S> Drop for TokenCountingStream<S> {
                     claims.is_staff,
                     provider,
                     &model,
-                    input_token_count,
-                    cache_creation_input_token_count,
-                    cache_read_input_token_count,
-                    output_token_count,
+                    tokens,
+                    claims.has_llm_subscription,
+                    Cents(claims.max_monthly_spend_in_cents),
                     Utc::now(),
                 )
                 .await
@@ -679,22 +683,23 @@ impl<S> Drop for TokenCountingStream<S> {
                             },
                             model,
                             provider: provider.to_string(),
-                            input_token_count: input_token_count as u64,
-                            cache_creation_input_token_count: cache_creation_input_token_count
-                                as u64,
-                            cache_read_input_token_count: cache_read_input_token_count as u64,
-                            output_token_count: output_token_count as u64,
+                            input_token_count: tokens.input as u64,
+                            cache_creation_input_token_count: tokens.input_cache_creation as u64,
+                            cache_read_input_token_count: tokens.input_cache_read as u64,
+                            output_token_count: tokens.output as u64,
                             requests_this_minute: usage.requests_this_minute as u64,
                             tokens_this_minute: usage.tokens_this_minute as u64,
                             tokens_this_day: usage.tokens_this_day as u64,
-                            input_tokens_this_month: usage.input_tokens_this_month as u64,
+                            input_tokens_this_month: usage.tokens_this_month.input as u64,
                             cache_creation_input_tokens_this_month: usage
-                                .cache_creation_input_tokens_this_month
+                                .tokens_this_month
+                                .input_cache_creation
                                 as u64,
                             cache_read_input_tokens_this_month: usage
-                                .cache_read_input_tokens_this_month
+                                .tokens_this_month
+                                .input_cache_read
                                 as u64,
-                            output_tokens_this_month: usage.output_tokens_this_month as u64,
+                            output_tokens_this_month: usage.tokens_this_month.output as u64,
                             spending_this_month: usage.spending_this_month.0 as u64,
                             lifetime_spending: usage.lifetime_spending.0 as u64,
                         },

crates/collab/src/llm/db.rs 🔗

@@ -20,7 +20,7 @@ use std::future::Future;
 use std::sync::Arc;
 
 use anyhow::anyhow;
-pub use queries::usages::ActiveUserCount;
+pub use queries::usages::{ActiveUserCount, TokenUsage};
 use sea_orm::prelude::*;
 pub use sea_orm::ConnectOptions;
 use sea_orm::{

crates/collab/src/llm/db/ids.rs 🔗

@@ -8,3 +8,4 @@ id_type!(ProviderId);
 id_type!(UsageId);
 id_type!(UsageMeasureId);
 id_type!(RevokedAccessTokenId);
+id_type!(BillingEventId);

crates/collab/src/llm/db/queries/billing_events.rs 🔗

@@ -0,0 +1,31 @@
+use super::*;
+use crate::Result;
+use anyhow::Context as _;
+
+impl LlmDatabase {
+    pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
+        self.transaction(|tx| async move {
+            let events_with_models = billing_event::Entity::find()
+                .find_also_related(model::Entity)
+                .all(&*tx)
+                .await?;
+            events_with_models
+                .into_iter()
+                .map(|(event, model)| {
+                    let model =
+                        model.context("could not find model associated with billing event")?;
+                    Ok((event, model))
+                })
+                .collect()
+        })
+        .await
+    }
+
+    pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
+        self.transaction(|tx| async move {
+            billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
+            Ok(())
+        })
+        .await
+    }
+}

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -1,5 +1,5 @@
-use crate::db::UserId;
 use crate::llm::Cents;
+use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
 use chrono::{Datelike, Duration};
 use futures::StreamExt as _;
 use rpc::LanguageModelProvider;
@@ -9,15 +9,26 @@ use strum::IntoEnumIterator as _;
 
 use super::*;
 
+#[derive(Debug, PartialEq, Clone, Copy, Default)]
+pub struct TokenUsage {
+    pub input: usize,
+    pub input_cache_creation: usize,
+    pub input_cache_read: usize,
+    pub output: usize,
+}
+
+impl TokenUsage {
+    pub fn total(&self) -> usize {
+        self.input + self.input_cache_creation + self.input_cache_read + self.output
+    }
+}
+
 #[derive(Debug, PartialEq, Clone, Copy)]
 pub struct Usage {
     pub requests_this_minute: usize,
     pub tokens_this_minute: usize,
     pub tokens_this_day: usize,
-    pub input_tokens_this_month: usize,
-    pub cache_creation_input_tokens_this_month: usize,
-    pub cache_read_input_tokens_this_month: usize,
-    pub output_tokens_this_month: usize,
+    pub tokens_this_month: TokenUsage,
     pub spending_this_month: Cents,
     pub lifetime_spending: Cents,
 }
@@ -257,18 +268,20 @@ impl LlmDatabase {
                 requests_this_minute,
                 tokens_this_minute,
                 tokens_this_day,
-                input_tokens_this_month: monthly_usage
-                    .as_ref()
-                    .map_or(0, |usage| usage.input_tokens as usize),
-                cache_creation_input_tokens_this_month: monthly_usage
-                    .as_ref()
-                    .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
-                cache_read_input_tokens_this_month: monthly_usage
-                    .as_ref()
-                    .map_or(0, |usage| usage.cache_read_input_tokens as usize),
-                output_tokens_this_month: monthly_usage
-                    .as_ref()
-                    .map_or(0, |usage| usage.output_tokens as usize),
+                tokens_this_month: TokenUsage {
+                    input: monthly_usage
+                        .as_ref()
+                        .map_or(0, |usage| usage.input_tokens as usize),
+                    input_cache_creation: monthly_usage
+                        .as_ref()
+                        .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
+                    input_cache_read: monthly_usage
+                        .as_ref()
+                        .map_or(0, |usage| usage.cache_read_input_tokens as usize),
+                    output: monthly_usage
+                        .as_ref()
+                        .map_or(0, |usage| usage.output_tokens as usize),
+                },
                 spending_this_month,
                 lifetime_spending,
             })
@@ -283,10 +296,9 @@ impl LlmDatabase {
         is_staff: bool,
         provider: LanguageModelProvider,
         model_name: &str,
-        input_token_count: usize,
-        cache_creation_input_tokens: usize,
-        cache_read_input_tokens: usize,
-        output_token_count: usize,
+        tokens: TokenUsage,
+        has_llm_subscription: bool,
+        max_monthly_spend: Cents,
         now: DateTimeUtc,
     ) -> Result<Usage> {
         self.transaction(|tx| async move {
@@ -313,10 +325,6 @@ impl LlmDatabase {
                     &tx,
                 )
                 .await?;
-            let total_token_count = input_token_count
-                + cache_read_input_tokens
-                + cache_creation_input_tokens
-                + output_token_count;
             let tokens_this_minute = self
                 .update_usage_for_measure(
                     user_id,
@@ -325,7 +333,7 @@ impl LlmDatabase {
                     &usages,
                     UsageMeasure::TokensPerMinute,
                     now,
-                    total_token_count,
+                    tokens.total(),
                     &tx,
                 )
                 .await?;
@@ -337,7 +345,7 @@ impl LlmDatabase {
                     &usages,
                     UsageMeasure::TokensPerDay,
                     now,
-                    total_token_count,
+                    tokens.total(),
                     &tx,
                 )
                 .await?;
@@ -361,18 +369,14 @@ impl LlmDatabase {
                 Some(usage) => {
                     monthly_usage::Entity::update(monthly_usage::ActiveModel {
                         id: ActiveValue::unchanged(usage.id),
-                        input_tokens: ActiveValue::set(
-                            usage.input_tokens + input_token_count as i64,
-                        ),
+                        input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
                         cache_creation_input_tokens: ActiveValue::set(
-                            usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
+                            usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
                         ),
                         cache_read_input_tokens: ActiveValue::set(
-                            usage.cache_read_input_tokens + cache_read_input_tokens as i64,
-                        ),
-                        output_tokens: ActiveValue::set(
-                            usage.output_tokens + output_token_count as i64,
+                            usage.cache_read_input_tokens + tokens.input_cache_read as i64,
                         ),
+                        output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
                         ..Default::default()
                     })
                     .exec(&*tx)
@@ -384,12 +388,12 @@ impl LlmDatabase {
                         model_id: ActiveValue::set(model.id),
                         month: ActiveValue::set(month),
                         year: ActiveValue::set(year),
-                        input_tokens: ActiveValue::set(input_token_count as i64),
+                        input_tokens: ActiveValue::set(tokens.input as i64),
                         cache_creation_input_tokens: ActiveValue::set(
-                            cache_creation_input_tokens as i64,
+                            tokens.input_cache_creation as i64,
                         ),
-                        cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
-                        output_tokens: ActiveValue::set(output_token_count as i64),
+                        cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
+                        output_tokens: ActiveValue::set(tokens.output as i64),
                         ..Default::default()
                     }
                     .insert(&*tx)
@@ -405,6 +409,26 @@ impl LlmDatabase {
                 monthly_usage.output_tokens as usize,
             );
 
+            if spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
+                && has_llm_subscription
+                && spending_this_month <= max_monthly_spend
+            {
+                billing_event::ActiveModel {
+                    id: ActiveValue::not_set(),
+                    idempotency_key: ActiveValue::not_set(),
+                    user_id: ActiveValue::set(user_id),
+                    model_id: ActiveValue::set(model.id),
+                    input_tokens: ActiveValue::set(tokens.input as i64),
+                    input_cache_creation_tokens: ActiveValue::set(
+                        tokens.input_cache_creation as i64,
+                    ),
+                    input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
+                    output_tokens: ActiveValue::set(tokens.output as i64),
+                }
+                .insert(&*tx)
+                .await?;
+            }
+
             // Update lifetime usage
             let lifetime_usage = lifetime_usage::Entity::find()
                 .filter(
@@ -419,18 +443,14 @@ impl LlmDatabase {
                 Some(usage) => {
                     lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
                         id: ActiveValue::unchanged(usage.id),
-                        input_tokens: ActiveValue::set(
-                            usage.input_tokens + input_token_count as i64,
-                        ),
+                        input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
                         cache_creation_input_tokens: ActiveValue::set(
-                            usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
+                            usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
                         ),
                         cache_read_input_tokens: ActiveValue::set(
-                            usage.cache_read_input_tokens + cache_read_input_tokens as i64,
-                        ),
-                        output_tokens: ActiveValue::set(
-                            usage.output_tokens + output_token_count as i64,
+                            usage.cache_read_input_tokens + tokens.input_cache_read as i64,
                         ),
+                        output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
                         ..Default::default()
                     })
                     .exec(&*tx)
@@ -440,12 +460,12 @@ impl LlmDatabase {
                     lifetime_usage::ActiveModel {
                         user_id: ActiveValue::set(user_id),
                         model_id: ActiveValue::set(model.id),
-                        input_tokens: ActiveValue::set(input_token_count as i64),
+                        input_tokens: ActiveValue::set(tokens.input as i64),
                         cache_creation_input_tokens: ActiveValue::set(
-                            cache_creation_input_tokens as i64,
+                            tokens.input_cache_creation as i64,
                         ),
-                        cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
-                        output_tokens: ActiveValue::set(output_token_count as i64),
+                        cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
+                        output_tokens: ActiveValue::set(tokens.output as i64),
                         ..Default::default()
                     }
                     .insert(&*tx)
@@ -465,11 +485,12 @@ impl LlmDatabase {
                 requests_this_minute,
                 tokens_this_minute,
                 tokens_this_day,
-                input_tokens_this_month: monthly_usage.input_tokens as usize,
-                cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
-                    as usize,
-                cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
-                output_tokens_this_month: monthly_usage.output_tokens as usize,
+                tokens_this_month: TokenUsage {
+                    input: monthly_usage.input_tokens as usize,
+                    input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
+                    input_cache_read: monthly_usage.cache_read_input_tokens as usize,
+                    output: monthly_usage.output_tokens as usize,
+                },
                 spending_this_month,
                 lifetime_spending,
             })

crates/collab/src/llm/db/tables/billing_event.rs 🔗

@@ -0,0 +1,37 @@
+use crate::{
+    db::UserId,
+    llm::db::{BillingEventId, ModelId},
+};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "billing_events")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: BillingEventId,
+    pub idempotency_key: Uuid,
+    pub user_id: UserId,
+    pub model_id: ModelId,
+    pub input_tokens: i64,
+    pub input_cache_creation_tokens: i64,
+    pub input_cache_read_tokens: i64,
+    pub output_tokens: i64,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::model::Entity",
+        from = "Column::ModelId",
+        to = "super::model::Column::Id"
+    )]
+    Model,
+}
+
+impl Related<super::model::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Model.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tables/model.rs 🔗

@@ -29,6 +29,8 @@ pub enum Relation {
     Provider,
     #[sea_orm(has_many = "super::usage::Entity")]
     Usages,
+    #[sea_orm(has_many = "super::billing_event::Entity")]
+    BillingEvents,
 }
 
 impl Related<super::provider::Entity> for Entity {
@@ -43,4 +45,10 @@ impl Related<super::usage::Entity> for Entity {
     }
 }
 
+impl Related<super::billing_event::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::BillingEvents.def()
+    }
+}
+
 impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tests/usage_tests.rs 🔗

@@ -2,7 +2,7 @@ use crate::{
     db::UserId,
     llm::db::{
         queries::{providers::ModelParams, usages::Usage},
-        LlmDatabase,
+        LlmDatabase, TokenUsage,
     },
     test_llm_db, Cents,
 };
@@ -36,14 +36,42 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
     let user_id = UserId::from_proto(123);
 
     let now = t0;
-    db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now)
-        .await
-        .unwrap();
+    db.record_usage(
+        user_id,
+        false,
+        provider,
+        model,
+        TokenUsage {
+            input: 1000,
+            input_cache_creation: 0,
+            input_cache_read: 0,
+            output: 0,
+        },
+        false,
+        Cents::ZERO,
+        now,
+    )
+    .await
+    .unwrap();
 
     let now = t0 + Duration::seconds(10);
-    db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now)
-        .await
-        .unwrap();
+    db.record_usage(
+        user_id,
+        false,
+        provider,
+        model,
+        TokenUsage {
+            input: 2000,
+            input_cache_creation: 0,
+            input_cache_read: 0,
+            output: 0,
+        },
+        false,
+        Cents::ZERO,
+        now,
+    )
+    .await
+    .unwrap();
 
     let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
     assert_eq!(
@@ -52,10 +80,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 2,
             tokens_this_minute: 3000,
             tokens_this_day: 3000,
-            input_tokens_this_month: 3000,
-            cache_creation_input_tokens_this_month: 0,
-            cache_read_input_tokens_this_month: 0,
-            output_tokens_this_month: 0,
+            tokens_this_month: TokenUsage {
+                input: 3000,
+                input_cache_creation: 0,
+                input_cache_read: 0,
+                output: 0,
+            },
             spending_this_month: Cents::ZERO,
             lifetime_spending: Cents::ZERO,
         }
@@ -69,19 +99,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 1,
             tokens_this_minute: 2000,
             tokens_this_day: 3000,
-            input_tokens_this_month: 3000,
-            cache_creation_input_tokens_this_month: 0,
-            cache_read_input_tokens_this_month: 0,
-            output_tokens_this_month: 0,
+            tokens_this_month: TokenUsage {
+                input: 3000,
+                input_cache_creation: 0,
+                input_cache_read: 0,
+                output: 0,
+            },
             spending_this_month: Cents::ZERO,
             lifetime_spending: Cents::ZERO,
         }
     );
 
     let now = t0 + Duration::seconds(60);
-    db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now)
-        .await
-        .unwrap();
+    db.record_usage(
+        user_id,
+        false,
+        provider,
+        model,
+        TokenUsage {
+            input: 3000,
+            input_cache_creation: 0,
+            input_cache_read: 0,
+            output: 0,
+        },
+        false,
+        Cents::ZERO,
+        now,
+    )
+    .await
+    .unwrap();
 
     let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
     assert_eq!(
@@ -90,10 +136,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 2,
             tokens_this_minute: 5000,
             tokens_this_day: 6000,
-            input_tokens_this_month: 6000,
-            cache_creation_input_tokens_this_month: 0,
-            cache_read_input_tokens_this_month: 0,
-            output_tokens_this_month: 0,
+            tokens_this_month: TokenUsage {
+                input: 6000,
+                input_cache_creation: 0,
+                input_cache_read: 0,
+                output: 0,
+            },
             spending_this_month: Cents::ZERO,
             lifetime_spending: Cents::ZERO,
         }
@@ -108,18 +156,34 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 0,
             tokens_this_minute: 0,
             tokens_this_day: 5000,
-            input_tokens_this_month: 6000,
-            cache_creation_input_tokens_this_month: 0,
-            cache_read_input_tokens_this_month: 0,
-            output_tokens_this_month: 0,
+            tokens_this_month: TokenUsage {
+                input: 6000,
+                input_cache_creation: 0,
+                input_cache_read: 0,
+                output: 0,
+            },
             spending_this_month: Cents::ZERO,
             lifetime_spending: Cents::ZERO,
         }
     );
 
-    db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now)
-        .await
-        .unwrap();
+    db.record_usage(
+        user_id,
+        false,
+        provider,
+        model,
+        TokenUsage {
+            input: 4000,
+            input_cache_creation: 0,
+            input_cache_read: 0,
+            output: 0,
+        },
+        false,
+        Cents::ZERO,
+        now,
+    )
+    .await
+    .unwrap();
 
     let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
     assert_eq!(
@@ -128,10 +192,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 1,
             tokens_this_minute: 4000,
             tokens_this_day: 9000,
-            input_tokens_this_month: 10000,
-            cache_creation_input_tokens_this_month: 0,
-            cache_read_input_tokens_this_month: 0,
-            output_tokens_this_month: 0,
+            tokens_this_month: TokenUsage {
+                input: 10000,
+                input_cache_creation: 0,
+                input_cache_read: 0,
+                output: 0,
+            },
             spending_this_month: Cents::ZERO,
             lifetime_spending: Cents::ZERO,
         }
@@ -143,9 +209,23 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         .with_timezone(&Utc);
 
     // Test cache creation input tokens
-    db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
-        .await
-        .unwrap();
+    db.record_usage(
+        user_id,
+        false,
+        provider,
+        model,
+        TokenUsage {
+            input: 1000,
+            input_cache_creation: 500,
+            input_cache_read: 0,
+            output: 0,
+        },
+        false,
+        Cents::ZERO,
+        now,
+    )
+    .await
+    .unwrap();
 
     let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
     assert_eq!(
@@ -154,19 +234,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 1,
             tokens_this_minute: 1500,
             tokens_this_day: 1500,
-            input_tokens_this_month: 1000,
-            cache_creation_input_tokens_this_month: 500,
-            cache_read_input_tokens_this_month: 0,
-            output_tokens_this_month: 0,
+            tokens_this_month: TokenUsage {
+                input: 1000,
+                input_cache_creation: 500,
+                input_cache_read: 0,
+                output: 0,
+            },
             spending_this_month: Cents::ZERO,
             lifetime_spending: Cents::ZERO,
         }
     );
 
     // Test cache read input tokens
-    db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now)
-        .await
-        .unwrap();
+    db.record_usage(
+        user_id,
+        false,
+        provider,
+        model,
+        TokenUsage {
+            input: 1000,
+            input_cache_creation: 0,
+            input_cache_read: 300,
+            output: 0,
+        },
+        false,
+        Cents::ZERO,
+        now,
+    )
+    .await
+    .unwrap();
 
     let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
     assert_eq!(
@@ -175,10 +271,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
             requests_this_minute: 2,
             tokens_this_minute: 2800,
             tokens_this_day: 2800,
-            input_tokens_this_month: 2000,
-            cache_creation_input_tokens_this_month: 500,
-            cache_read_input_tokens_this_month: 300,
-            output_tokens_this_month: 0,
+            tokens_this_month: TokenUsage {
+                input: 2000,
+                input_cache_creation: 500,
+                input_cache_read: 300,
+                output: 0,
+            },
             spending_this_month: Cents::ZERO,
             lifetime_spending: Cents::ZERO,
         }

crates/collab/src/main.rs 🔗

@@ -157,7 +157,7 @@ async fn main() -> Result<()> {
 
                     if let Some(mut llm_db) = llm_db {
                         llm_db.initialize().await?;
-                        sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
+                        sync_llm_usage_with_stripe_periodically(state.clone());
                     }
 
                     app = app

crates/collab/src/rpc.rs 🔗

@@ -1218,6 +1218,15 @@ impl Server {
         Ok(())
     }
 
+    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
+        let pool = self.connection_pool.lock();
+        for connection_id in pool.user_connection_ids(user_id) {
+            self.peer
+                .send(connection_id, proto::RefreshLlmToken {})
+                .trace_err();
+        }
+    }
+
     pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
         ServerSnapshot {
             connection_pool: ConnectionPoolGuard {

crates/collab/src/stripe_billing.rs 🔗

@@ -0,0 +1,427 @@
+use std::sync::Arc;
+
+use crate::{llm, Cents, Result};
+use anyhow::Context;
+use chrono::Utc;
+use collections::HashMap;
+use serde::{Deserialize, Serialize};
+
+pub struct StripeBilling {
+    meters_by_event_name: HashMap<String, StripeMeter>,
+    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+    client: Arc<stripe::Client>,
+}
+
+pub struct StripeModel {
+    input_tokens_price: StripeBillingPrice,
+    input_cache_creation_tokens_price: StripeBillingPrice,
+    input_cache_read_tokens_price: StripeBillingPrice,
+    output_tokens_price: StripeBillingPrice,
+}
+
+struct StripeBillingPrice {
+    id: stripe::PriceId,
+    meter_event_name: String,
+}
+
+impl StripeBilling {
+    pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
+        let mut meters_by_event_name = HashMap::default();
+        for meter in StripeMeter::list(&client).await?.data {
+            meters_by_event_name.insert(meter.event_name.clone(), meter);
+        }
+
+        let mut price_ids_by_meter_id = HashMap::default();
+        for price in stripe::Price::list(&client, &stripe::ListPrices::default())
+            .await?
+            .data
+        {
+            if let Some(recurring) = price.recurring {
+                if let Some(meter) = recurring.meter {
+                    price_ids_by_meter_id.insert(meter, price.id);
+                }
+            }
+        }
+
+        Ok(Self {
+            meters_by_event_name,
+            price_ids_by_meter_id,
+            client,
+        })
+    }
+
+    pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
+        let input_tokens_price = self
+            .get_or_insert_price(
+                &format!("model_{}/input_tokens", model.id),
+                &format!("{} (Input Tokens)", model.name),
+                Cents::new(model.price_per_million_input_tokens as u32),
+            )
+            .await?;
+        let input_cache_creation_tokens_price = self
+            .get_or_insert_price(
+                &format!("model_{}/input_cache_creation_tokens", model.id),
+                &format!("{} (Input Cache Creation Tokens)", model.name),
+                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
+            )
+            .await?;
+        let input_cache_read_tokens_price = self
+            .get_or_insert_price(
+                &format!("model_{}/input_cache_read_tokens", model.id),
+                &format!("{} (Input Cache Read Tokens)", model.name),
+                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
+            )
+            .await?;
+        let output_tokens_price = self
+            .get_or_insert_price(
+                &format!("model_{}/output_tokens", model.id),
+                &format!("{} (Output Tokens)", model.name),
+                Cents::new(model.price_per_million_output_tokens as u32),
+            )
+            .await?;
+        Ok(StripeModel {
+            input_tokens_price,
+            input_cache_creation_tokens_price,
+            input_cache_read_tokens_price,
+            output_tokens_price,
+        })
+    }
+
+    async fn get_or_insert_price(
+        &mut self,
+        meter_event_name: &str,
+        price_description: &str,
+        price_per_million_tokens: Cents,
+    ) -> Result<StripeBillingPrice> {
+        let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
+            meter.clone()
+        } else {
+            let meter = StripeMeter::create(
+                &self.client,
+                StripeCreateMeterParams {
+                    default_aggregation: DefaultAggregation { formula: "sum" },
+                    display_name: price_description.to_string(),
+                    event_name: meter_event_name,
+                },
+            )
+            .await?;
+            self.meters_by_event_name
+                .insert(meter_event_name.to_string(), meter.clone());
+            meter
+        };
+
+        let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
+            price_id.clone()
+        } else {
+            let price = stripe::Price::create(
+                &self.client,
+                stripe::CreatePrice {
+                    active: Some(true),
+                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
+                    currency: stripe::Currency::USD,
+                    currency_options: None,
+                    custom_unit_amount: None,
+                    expand: &[],
+                    lookup_key: None,
+                    metadata: None,
+                    nickname: None,
+                    product: None,
+                    product_data: Some(stripe::CreatePriceProductData {
+                        id: None,
+                        active: Some(true),
+                        metadata: None,
+                        name: price_description.to_string(),
+                        statement_descriptor: None,
+                        tax_code: None,
+                        unit_label: None,
+                    }),
+                    recurring: Some(stripe::CreatePriceRecurring {
+                        aggregate_usage: None,
+                        interval: stripe::CreatePriceRecurringInterval::Month,
+                        interval_count: None,
+                        trial_period_days: None,
+                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
+                        meter: Some(meter.id.clone()),
+                    }),
+                    tax_behavior: None,
+                    tiers: None,
+                    tiers_mode: None,
+                    transfer_lookup_key: None,
+                    transform_quantity: None,
+                    unit_amount: None,
+                    unit_amount_decimal: Some(&format!(
+                        "{:.12}",
+                        price_per_million_tokens.0 as f64 / 1_000_000f64
+                    )),
+                },
+            )
+            .await?;
+            self.price_ids_by_meter_id
+                .insert(meter.id, price.id.clone());
+            price.id
+        };
+
+        Ok(StripeBillingPrice {
+            id: price_id,
+            meter_event_name: meter_event_name.to_string(),
+        })
+    }
+
+    pub async fn subscribe_to_model(
+        &self,
+        subscription_id: &stripe::SubscriptionId,
+        model: &StripeModel,
+    ) -> Result<()> {
+        let subscription =
+            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+
+        let mut items = Vec::new();
+
+        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
+            items.push(stripe::UpdateSubscriptionItems {
+                price: Some(model.input_tokens_price.id.to_string()),
+                ..Default::default()
+            });
+        }
+
+        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
+        {
+            items.push(stripe::UpdateSubscriptionItems {
+                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
+                ..Default::default()
+            });
+        }
+
+        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
+            items.push(stripe::UpdateSubscriptionItems {
+                price: Some(model.input_cache_read_tokens_price.id.to_string()),
+                ..Default::default()
+            });
+        }
+
+        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
+            items.push(stripe::UpdateSubscriptionItems {
+                price: Some(model.output_tokens_price.id.to_string()),
+                ..Default::default()
+            });
+        }
+
+        if !items.is_empty() {
+            items.extend(subscription.items.data.iter().map(|item| {
+                stripe::UpdateSubscriptionItems {
+                    id: Some(item.id.to_string()),
+                    ..Default::default()
+                }
+            }));
+
+            stripe::Subscription::update(
+                &self.client,
+                subscription_id,
+                stripe::UpdateSubscription {
+                    items: Some(items),
+                    ..Default::default()
+                },
+            )
+            .await?;
+        }
+
+        Ok(())
+    }
+
+    pub async fn bill_model_usage(
+        &self,
+        customer_id: &stripe::CustomerId,
+        model: &StripeModel,
+        event: &llm::db::billing_event::Model,
+    ) -> Result<()> {
+        let timestamp = Utc::now().timestamp();
+
+        if event.input_tokens > 0 {
+            StripeMeterEvent::create(
+                &self.client,
+                StripeCreateMeterEventParams {
+                    identifier: &format!("input_tokens/{}", event.idempotency_key),
+                    event_name: &model.input_tokens_price.meter_event_name,
+                    payload: StripeCreateMeterEventPayload {
+                        value: event.input_tokens as u64,
+                        stripe_customer_id: customer_id,
+                    },
+                    timestamp: Some(timestamp),
+                },
+            )
+            .await?;
+        }
+
+        if event.input_cache_creation_tokens > 0 {
+            StripeMeterEvent::create(
+                &self.client,
+                StripeCreateMeterEventParams {
+                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
+                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
+                    payload: StripeCreateMeterEventPayload {
+                        value: event.input_cache_creation_tokens as u64,
+                        stripe_customer_id: customer_id,
+                    },
+                    timestamp: Some(timestamp),
+                },
+            )
+            .await?;
+        }
+
+        if event.input_cache_read_tokens > 0 {
+            StripeMeterEvent::create(
+                &self.client,
+                StripeCreateMeterEventParams {
+                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
+                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
+                    payload: StripeCreateMeterEventPayload {
+                        value: event.input_cache_read_tokens as u64,
+                        stripe_customer_id: customer_id,
+                    },
+                    timestamp: Some(timestamp),
+                },
+            )
+            .await?;
+        }
+
+        if event.output_tokens > 0 {
+            StripeMeterEvent::create(
+                &self.client,
+                StripeCreateMeterEventParams {
+                    identifier: &format!("output_tokens/{}", event.idempotency_key),
+                    event_name: &model.output_tokens_price.meter_event_name,
+                    payload: StripeCreateMeterEventPayload {
+                        value: event.output_tokens as u64,
+                        stripe_customer_id: customer_id,
+                    },
+                    timestamp: Some(timestamp),
+                },
+            )
+            .await?;
+        }
+
+        Ok(())
+    }
+
+    pub async fn checkout(
+        &self,
+        customer_id: stripe::CustomerId,
+        github_login: &str,
+        model: &StripeModel,
+        success_url: &str,
+    ) -> Result<String> {
+        let mut params = stripe::CreateCheckoutSession::new();
+        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+        params.customer = Some(customer_id);
+        params.client_reference_id = Some(github_login);
+        params.line_items = Some(
+            [
+                &model.input_tokens_price.id,
+                &model.input_cache_creation_tokens_price.id,
+                &model.input_cache_read_tokens_price.id,
+                &model.output_tokens_price.id,
+            ]
+            .into_iter()
+            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
+                price: Some(price_id.to_string()),
+                ..Default::default()
+            })
+            .collect(),
+        );
+        params.success_url = Some(success_url);
+
+        let session = stripe::CheckoutSession::create(&self.client, params).await?;
+        Ok(session.url.context("no checkout session URL")?)
+    }
+}
+
+#[derive(Serialize)]
+struct DefaultAggregation {
+    formula: &'static str,
+}
+
+#[derive(Serialize)]
+struct StripeCreateMeterParams<'a> {
+    default_aggregation: DefaultAggregation,
+    display_name: String,
+    event_name: &'a str,
+}
+
+#[derive(Clone, Deserialize)]
+struct StripeMeter {
+    id: String,
+    event_name: String,
+}
+
+impl StripeMeter {
+    pub fn create(
+        client: &stripe::Client,
+        params: StripeCreateMeterParams,
+    ) -> stripe::Response<Self> {
+        client.post_form("/billing/meters", params)
+    }
+
+    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
+        #[derive(Serialize)]
+        struct Params {}
+
+        client.get_query("/billing/meters", Params {})
+    }
+}
+
+#[derive(Deserialize)]
+struct StripeMeterEvent {
+    identifier: String,
+}
+
+impl StripeMeterEvent {
+    pub async fn create(
+        client: &stripe::Client,
+        params: StripeCreateMeterEventParams<'_>,
+    ) -> Result<Self, stripe::StripeError> {
+        let identifier = params.identifier;
+        match client.post_form("/billing/meter_events", params).await {
+            Ok(event) => Ok(event),
+            Err(stripe::StripeError::Stripe(error)) => {
+                if error.http_status == 400
+                    && error
+                        .message
+                        .as_ref()
+                        .map_or(false, |message| message.contains(identifier))
+                {
+                    Ok(Self {
+                        identifier: identifier.to_string(),
+                    })
+                } else {
+                    Err(stripe::StripeError::Stripe(error))
+                }
+            }
+            Err(error) => Err(error),
+        }
+    }
+}
+
+#[derive(Serialize)]
+struct StripeCreateMeterEventParams<'a> {
+    identifier: &'a str,
+    event_name: &'a str,
+    payload: StripeCreateMeterEventPayload<'a>,
+    timestamp: Option<i64>,
+}
+
+#[derive(Serialize)]
+struct StripeCreateMeterEventPayload<'a> {
+    value: u64,
+    stripe_customer_id: &'a stripe::CustomerId,
+}
+
+fn subscription_contains_price(
+    subscription: &stripe::Subscription,
+    price_id: &stripe::PriceId,
+) -> bool {
+    subscription.items.data.iter().any(|item| {
+        item.price
+            .as_ref()
+            .map_or(false, |price| price.id == *price_id)
+    })
+}

crates/collab/src/tests/test_server.rs 🔗

@@ -635,6 +635,7 @@ impl TestServer {
     ) -> Arc<AppState> {
         Arc::new(AppState {
             db: test_db.db().clone(),
+            llm_db: None,
             live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
             blob_store_client: None,
             stripe_client: None,
@@ -677,8 +678,6 @@ impl TestServer {
                 migrations_path: None,
                 seed_path: None,
                 stripe_api_key: None,
-                stripe_llm_access_price_id: None,
-                stripe_llm_usage_price_id: None,
                 supermaven_admin_api_key: None,
                 user_backfiller_github_access_token: None,
             },