collab: Sync model overages for all active Zed Pro subscriptions (#34452)

Marshall Bowers created

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      | 145 ++++++++----
crates/collab/src/db/queries/billing_subscriptions.rs |  27 ++
2 files changed, 117 insertions(+), 55 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -5,7 +5,7 @@ use axum::{
     routing::{get, post},
 };
 use chrono::{DateTime, SecondsFormat, Utc};
-use collections::HashSet;
+use collections::{HashMap, HashSet};
 use reqwest::StatusCode;
 use sea_orm::ActiveValue;
 use serde::{Deserialize, Serialize};
@@ -21,12 +21,13 @@ use stripe::{
     PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
 };
 use util::{ResultExt, maybe};
+use zed_llm_client::LanguageModelProvider;
 
 use crate::api::events::SnowflakeRow;
 use crate::db::billing_subscription::{
     StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
 };
-use crate::llm::db::subscription_usage_meter::CompletionMode;
+use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
 use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
 use crate::rpc::{ResultExt as _, Server};
 use crate::stripe_client::{
@@ -1416,18 +1417,21 @@ async fn sync_model_request_usage_with_stripe(
     let usage_meters = llm_db
         .get_current_subscription_usage_meters(Utc::now())
         .await?;
-    let usage_meters = usage_meters
-        .into_iter()
-        .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
-        .collect::<Vec<_>>();
-    let user_ids = usage_meters
-        .iter()
-        .map(|(_, usage)| usage.user_id)
-        .collect::<HashSet<UserId>>();
-    let billing_subscriptions = app
-        .db
-        .get_active_zed_pro_billing_subscriptions(user_ids)
-        .await?;
+    let mut usage_meters_by_user_id =
+        HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
+    for (usage_meter, usage) in usage_meters {
+        let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
+        meters.push(usage_meter);
+    }
+
+    log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
+    let get_zed_pro_subscriptions_started_at = Utc::now();
+    let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
+    log::info!(
+        "Stripe usage sync: Retrieved {} Zed pro subscriptions in {}",
+        billing_subscriptions.len(),
+        Utc::now() - get_zed_pro_subscriptions_started_at
+    );
 
     let claude_sonnet_4 = stripe_billing
         .find_price_by_lookup_key("claude-sonnet-4-requests")
@@ -1451,59 +1455,90 @@ async fn sync_model_request_usage_with_stripe(
         .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
         .await?;
 
-    let usage_meter_count = usage_meters.len();
+    let model_mode_combinations = [
+        ("claude-opus-4", CompletionMode::Max),
+        ("claude-opus-4", CompletionMode::Normal),
+        ("claude-sonnet-4", CompletionMode::Max),
+        ("claude-sonnet-4", CompletionMode::Normal),
+        ("claude-3-7-sonnet", CompletionMode::Max),
+        ("claude-3-7-sonnet", CompletionMode::Normal),
+        ("claude-3-5-sonnet", CompletionMode::Normal),
+    ];
 
-    log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
+    let billing_subscription_count = billing_subscriptions.len();
 
-    for (usage_meter, usage) in usage_meters {
+    log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
+
+    for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
         maybe!(async {
-            let Some((billing_customer, billing_subscription)) =
-                billing_subscriptions.get(&usage.user_id)
-            else {
-                bail!(
-                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
-                    usage.user_id
-                );
-            };
+            if staff_user_ids.contains(&user_id) {
+                return anyhow::Ok(());
+            }
 
             let stripe_customer_id =
                 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
             let stripe_subscription_id =
                 StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
 
-            let model = llm_db.model_by_id(usage_meter.model_id)?;
+            let usage_meters = usage_meters_by_user_id.get(&user_id);
 
-            let (price, meter_event_name) = match model.name.as_str() {
-                "claude-opus-4" => match usage_meter.mode {
-                    CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
-                    CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
-                },
-                "claude-sonnet-4" => match usage_meter.mode {
-                    CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
-                    CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
-                },
-                "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
-                "claude-3-7-sonnet" => match usage_meter.mode {
-                    CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
-                    CompletionMode::Max => {
-                        (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
+            for (model, mode) in &model_mode_combinations {
+                let Ok(model) =
+                    llm_db.model(LanguageModelProvider::Anthropic, model)
+                else {
+                    log::warn!("Failed to load model for user {user_id}: {model}");
+                    continue;
+                };
+
+                let (price, meter_event_name) = match model.name.as_str() {
+                    "claude-opus-4" => match mode {
+                        CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
+                        CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
+                    },
+                    "claude-sonnet-4" => match mode {
+                        CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
+                        CompletionMode::Max => {
+                            (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
+                        }
+                    },
+                    "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
+                    "claude-3-7-sonnet" => match mode {
+                        CompletionMode::Normal => {
+                            (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
+                        }
+                        CompletionMode::Max => {
+                            (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
+                        }
+                    },
+                    model_name => {
+                        bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
                     }
-                },
-                model_name => {
-                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
+                };
+
+                let model_requests = usage_meters
+                    .and_then(|usage_meters| {
+                        usage_meters
+                            .iter()
+                            .find(|meter| meter.model_id == model.id && meter.mode == *mode)
+                    })
+                    .map(|usage_meter| usage_meter.requests)
+                    .unwrap_or(0);
+
+                if model_requests > 0 {
+                    stripe_billing
+                        .subscribe_to_price(&stripe_subscription_id, price)
+                        .await?;
                 }
-            };
 
-            stripe_billing
-                .subscribe_to_price(&stripe_subscription_id, price)
-                .await?;
-            stripe_billing
-                .bill_model_request_usage(
-                    &stripe_customer_id,
-                    meter_event_name,
-                    usage_meter.requests,
-                )
-                .await?;
+                stripe_billing
+                    .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
+                    .await
+                    .with_context(|| {
+                        format!(
+                            "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
+                        )
+                    })?;
+            }
 
             Ok(())
         })
@@ -1512,7 +1547,7 @@ async fn sync_model_request_usage_with_stripe(
     }
 
     log::info!(
-        "Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
+        "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
         Utc::now() - started_at
     );
 

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -199,6 +199,33 @@ impl Database {
 
     pub async fn get_active_zed_pro_billing_subscriptions(
         &self,
+    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
+        self.transaction(|tx| async move {
+            let mut rows = billing_subscription::Entity::find()
+                .inner_join(billing_customer::Entity)
+                .select_also(billing_customer::Entity)
+                .filter(
+                    billing_subscription::Column::StripeSubscriptionStatus
+                        .eq(StripeSubscriptionStatus::Active),
+                )
+                .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
+                .order_by_asc(billing_subscription::Column::Id)
+                .stream(&*tx)
+                .await?;
+
+            let mut subscriptions = HashMap::default();
+            while let Some(row) = rows.next().await {
+                if let (subscription, Some(customer)) = row? {
+                    subscriptions.insert(customer.user_id, (customer, subscription));
+                }
+            }
+            Ok(subscriptions)
+        })
+        .await
+    }
+
+    pub async fn get_active_zed_pro_billing_subscriptions_for_users(
+        &self,
         user_ids: HashSet<UserId>,
     ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
         self.transaction(|tx| {