@@ -5,7 +5,7 @@ use axum::{
routing::{get, post},
};
use chrono::{DateTime, SecondsFormat, Utc};
-use collections::HashSet;
+use collections::{HashMap, HashSet};
use reqwest::StatusCode;
use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize};
@@ -21,12 +21,13 @@ use stripe::{
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe};
+use zed_llm_client::LanguageModelProvider;
use crate::api::events::SnowflakeRow;
use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
-use crate::llm::db::subscription_usage_meter::CompletionMode;
+use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{
@@ -1416,18 +1417,21 @@ async fn sync_model_request_usage_with_stripe(
let usage_meters = llm_db
.get_current_subscription_usage_meters(Utc::now())
.await?;
- let usage_meters = usage_meters
- .into_iter()
- .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
- .collect::<Vec<_>>();
- let user_ids = usage_meters
- .iter()
- .map(|(_, usage)| usage.user_id)
- .collect::<HashSet<UserId>>();
- let billing_subscriptions = app
- .db
- .get_active_zed_pro_billing_subscriptions(user_ids)
- .await?;
+ let mut usage_meters_by_user_id =
+ HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
+ for (usage_meter, usage) in usage_meters {
+ let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
+ meters.push(usage_meter);
+ }
+
+ log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
+ let get_zed_pro_subscriptions_started_at = Utc::now();
+ let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
+ log::info!(
+ "Stripe usage sync: Retrieved {} Zed pro subscriptions in {}",
+ billing_subscriptions.len(),
+ Utc::now() - get_zed_pro_subscriptions_started_at
+ );
let claude_sonnet_4 = stripe_billing
.find_price_by_lookup_key("claude-sonnet-4-requests")
@@ -1451,59 +1455,90 @@ async fn sync_model_request_usage_with_stripe(
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?;
- let usage_meter_count = usage_meters.len();
+ let model_mode_combinations = [
+ ("claude-opus-4", CompletionMode::Max),
+ ("claude-opus-4", CompletionMode::Normal),
+ ("claude-sonnet-4", CompletionMode::Max),
+ ("claude-sonnet-4", CompletionMode::Normal),
+ ("claude-3-7-sonnet", CompletionMode::Max),
+ ("claude-3-7-sonnet", CompletionMode::Normal),
+ ("claude-3-5-sonnet", CompletionMode::Normal),
+ ];
- log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
+ let billing_subscription_count = billing_subscriptions.len();
- for (usage_meter, usage) in usage_meters {
+ log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
+
+ for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
maybe!(async {
- let Some((billing_customer, billing_subscription)) =
- billing_subscriptions.get(&usage.user_id)
- else {
- bail!(
- "Attempted to sync usage meter for user who is not a Stripe customer: {}",
- usage.user_id
- );
- };
+ if staff_user_ids.contains(&user_id) {
+ return anyhow::Ok(());
+ }
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription_id =
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
- let model = llm_db.model_by_id(usage_meter.model_id)?;
+ let usage_meters = usage_meters_by_user_id.get(&user_id);
- let (price, meter_event_name) = match model.name.as_str() {
- "claude-opus-4" => match usage_meter.mode {
- CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
- CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
- },
- "claude-sonnet-4" => match usage_meter.mode {
- CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
- CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
- },
- "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
- "claude-3-7-sonnet" => match usage_meter.mode {
- CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
- CompletionMode::Max => {
- (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
+ for (model, mode) in &model_mode_combinations {
+ let Ok(model) =
+ llm_db.model(LanguageModelProvider::Anthropic, model)
+ else {
+ log::warn!("Failed to load model for user {user_id}: {model}");
+ continue;
+ };
+
+ let (price, meter_event_name) = match model.name.as_str() {
+ "claude-opus-4" => match mode {
+ CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
+ CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
+ },
+ "claude-sonnet-4" => match mode {
+ CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
+ CompletionMode::Max => {
+ (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
+ }
+ },
+ "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
+ "claude-3-7-sonnet" => match mode {
+ CompletionMode::Normal => {
+ (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
+ }
+ CompletionMode::Max => {
+ (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
+ }
+ },
+ model_name => {
+ bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
}
- },
- model_name => {
- bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
+ };
+
+ let model_requests = usage_meters
+ .and_then(|usage_meters| {
+ usage_meters
+ .iter()
+ .find(|meter| meter.model_id == model.id && meter.mode == *mode)
+ })
+ .map(|usage_meter| usage_meter.requests)
+ .unwrap_or(0);
+
+ if model_requests > 0 {
+ stripe_billing
+ .subscribe_to_price(&stripe_subscription_id, price)
+ .await?;
}
- };
- stripe_billing
- .subscribe_to_price(&stripe_subscription_id, price)
- .await?;
- stripe_billing
- .bill_model_request_usage(
- &stripe_customer_id,
- meter_event_name,
- usage_meter.requests,
- )
- .await?;
+ stripe_billing
+ .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
+ .await
+ .with_context(|| {
+ format!(
+ "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
+ )
+ })?;
+ }
Ok(())
})
@@ -1512,7 +1547,7 @@ async fn sync_model_request_usage_with_stripe(
}
log::info!(
- "Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
+ "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
Utc::now() - started_at
);
@@ -199,6 +199,33 @@ impl Database {
pub async fn get_active_zed_pro_billing_subscriptions(
&self,
+ ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
+ self.transaction(|tx| async move {
+ let mut rows = billing_subscription::Entity::find()
+ .inner_join(billing_customer::Entity)
+ .select_also(billing_customer::Entity)
+ .filter(
+ billing_subscription::Column::StripeSubscriptionStatus
+ .eq(StripeSubscriptionStatus::Active),
+ )
+ .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
+ .order_by_asc(billing_subscription::Column::Id)
+ .stream(&*tx)
+ .await?;
+
+ let mut subscriptions = HashMap::default();
+ while let Some(row) = rows.next().await {
+ if let (subscription, Some(customer)) = row? {
+ subscriptions.insert(customer.user_id, (customer, subscription));
+ }
+ }
+ Ok(subscriptions)
+ })
+ .await
+ }
+
+ pub async fn get_active_zed_pro_billing_subscriptions_for_users(
+ &self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {