@@ -393,7 +393,9 @@ async fn create_billing_subscription(
zed_llm_client::LanguageModelProvider::Anthropic,
"claude-3-7-sonnet",
)?;
- let stripe_model = stripe_billing.register_model(default_model).await?;
+ let stripe_model = stripe_billing
+ .register_model_for_token_based_usage(default_model)
+ .await?;
stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?
@@ -1303,7 +1305,9 @@ async fn sync_token_usage_with_stripe(
.parse()
.context("failed to parse stripe customer id from db")?;
- let stripe_model = stripe_billing.register_model(&model).await?;
+ let stripe_model = stripe_billing
+ .register_model_for_token_based_usage(&model)
+ .await?;
stripe_billing
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
.await?;
@@ -1315,3 +1319,106 @@ async fn sync_token_usage_with_stripe(
Ok(())
}
+
+const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
+
+pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
+ let Some(stripe_billing) = app.stripe_billing.clone() else {
+ log::warn!("failed to retrieve Stripe billing object");
+ return;
+ };
+ let Some(llm_db) = app.llm_db.clone() else {
+ log::warn!("failed to retrieve LLM database");
+ return;
+ };
+
+ let executor = app.executor.clone();
+ executor.spawn_detached({
+ let executor = executor.clone();
+ async move {
+ loop {
+ sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
+ .await
+ .context("failed to sync LLM request usage to Stripe")
+ .trace_err();
+ executor
+ .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
+ .await;
+ }
+ }
+ });
+}
+
+async fn sync_model_request_usage_with_stripe(
+ app: &Arc<AppState>,
+ llm_db: &Arc<LlmDatabase>,
+ stripe_billing: &Arc<StripeBilling>,
+) -> anyhow::Result<()> {
+ let usage_meters = llm_db
+ .get_current_subscription_usage_meters(Utc::now())
+ .await?;
+ let user_ids = usage_meters
+ .iter()
+ .map(|(_, usage)| usage.user_id)
+ .collect::<HashSet<UserId>>();
+ let billing_subscriptions = app
+ .db
+ .get_active_zed_pro_billing_subscriptions(user_ids)
+ .await?;
+
+ let claude_3_5_sonnet = stripe_billing
+ .find_price_by_lookup_key("claude-3-5-sonnet-requests")
+ .await?;
+ let claude_3_7_sonnet = stripe_billing
+ .find_price_by_lookup_key("claude-3-7-sonnet-requests")
+ .await?;
+
+ for (usage_meter, usage) in usage_meters {
+ maybe!(async {
+ let Some((billing_customer, billing_subscription)) =
+ billing_subscriptions.get(&usage.user_id)
+ else {
+ bail!(
+ "Attempted to sync usage meter for user who is not a Stripe customer: {}",
+ usage.user_id
+ );
+ };
+
+ let stripe_customer_id = billing_customer
+ .stripe_customer_id
+ .parse::<stripe::CustomerId>()
+ .context("failed to parse Stripe customer ID from database")?;
+ let stripe_subscription_id = billing_subscription
+ .stripe_subscription_id
+ .parse::<stripe::SubscriptionId>()
+ .context("failed to parse Stripe subscription ID from database")?;
+
+ let model = llm_db.model_by_id(usage_meter.model_id)?;
+
+ let (price_id, meter_event_name) = match model.name.as_str() {
+ "claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
+ "claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
+ model_name => {
+ bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
+ }
+ };
+
+ stripe_billing
+ .subscribe_to_price(&stripe_subscription_id, price_id)
+ .await?;
+ stripe_billing
+ .bill_model_request_usage(
+ &stripe_customer_id,
+ meter_event_name,
+ usage_meter.requests,
+ )
+ .await?;
+
+ Ok(())
+ })
+ .await
+ .log_err();
+ }
+
+ Ok(())
+}
@@ -0,0 +1,43 @@
+use sea_orm::entity::prelude::*;
+
+use crate::llm::db::ModelId;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "subscription_usage_meters")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: i32,
+ pub subscription_usage_id: i32,
+ pub model_id: ModelId,
+ pub requests: i32,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::subscription_usage::Entity",
+ from = "Column::SubscriptionUsageId",
+ to = "super::subscription_usage::Column::Id"
+ )]
+ SubscriptionUsage,
+ #[sea_orm(
+ belongs_to = "super::model::Entity",
+ from = "Column::ModelId",
+ to = "super::model::Column::Id"
+ )]
+ Model,
+}
+
+impl Related<super::subscription_usage::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::SubscriptionUsage.def()
+ }
+}
+
+impl Related<super::model::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Model.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -1,12 +1,13 @@
use std::sync::Arc;
use crate::{Cents, Result, llm};
-use anyhow::Context as _;
+use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc};
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::PriceId;
use tokio::sync::RwLock;
+use uuid::Uuid;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
@@ -17,9 +18,10 @@ pub struct StripeBilling {
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+ prices_by_lookup_key: HashMap<String, stripe::Price>,
}
-pub struct StripeModel {
+pub struct StripeModelTokenPrices {
input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice,
input_cache_read_tokens_price: StripeBillingPrice,
@@ -62,6 +64,10 @@ impl StripeBilling {
}
for price in prices.data {
+ if let Some(lookup_key) = price.lookup_key.clone() {
+ state.prices_by_lookup_key.insert(lookup_key, price.clone());
+ }
+
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
state.price_ids_by_meter_id.insert(meter, price.id);
@@ -74,36 +80,49 @@ impl StripeBilling {
Ok(())
}
- pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
+ pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
+ self.state
+ .read()
+ .await
+ .prices_by_lookup_key
+ .get(lookup_key)
+ .cloned()
+ .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
+ }
+
+ pub async fn register_model_for_token_based_usage(
+ &self,
+ model: &llm::db::model::Model,
+ ) -> Result<StripeModelTokenPrices> {
let input_tokens_price = self
- .get_or_insert_price(
+ .get_or_insert_token_price(
&format!("model_{}/input_tokens", model.id),
&format!("{} (Input Tokens)", model.name),
Cents::new(model.price_per_million_input_tokens as u32),
)
.await?;
let input_cache_creation_tokens_price = self
- .get_or_insert_price(
+ .get_or_insert_token_price(
&format!("model_{}/input_cache_creation_tokens", model.id),
&format!("{} (Input Cache Creation Tokens)", model.name),
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
)
.await?;
let input_cache_read_tokens_price = self
- .get_or_insert_price(
+ .get_or_insert_token_price(
&format!("model_{}/input_cache_read_tokens", model.id),
&format!("{} (Input Cache Read Tokens)", model.name),
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
)
.await?;
let output_tokens_price = self
- .get_or_insert_price(
+ .get_or_insert_token_price(
&format!("model_{}/output_tokens", model.id),
&format!("{} (Output Tokens)", model.name),
Cents::new(model.price_per_million_output_tokens as u32),
)
.await?;
- Ok(StripeModel {
+ Ok(StripeModelTokenPrices {
input_tokens_price,
input_cache_creation_tokens_price,
input_cache_read_tokens_price,
@@ -111,7 +130,7 @@ impl StripeBilling {
})
}
- async fn get_or_insert_price(
+ async fn get_or_insert_token_price(
&self,
meter_event_name: &str,
price_description: &str,
@@ -207,10 +226,43 @@ impl StripeBilling {
})
}
+ pub async fn subscribe_to_price(
+ &self,
+ subscription_id: &stripe::SubscriptionId,
+ price_id: &stripe::PriceId,
+ ) -> Result<()> {
+ let subscription =
+ stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+
+ if subscription_contains_price(&subscription, price_id) {
+ return Ok(());
+ }
+
+ stripe::Subscription::update(
+ &self.client,
+ subscription_id,
+ stripe::UpdateSubscription {
+ items: Some(vec![stripe::UpdateSubscriptionItems {
+ price: Some(price_id.to_string()),
+ ..Default::default()
+ }]),
+ trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
+ end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
+ missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+ },
+ }),
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ Ok(())
+ }
+
pub async fn subscribe_to_model(
&self,
subscription_id: &stripe::SubscriptionId,
- model: &StripeModel,
+ model: &StripeModelTokenPrices,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
@@ -271,7 +323,7 @@ impl StripeBilling {
pub async fn bill_model_token_usage(
&self,
customer_id: &stripe::CustomerId,
- model: &StripeModel,
+ model: &StripeModelTokenPrices,
event: &llm::db::billing_event::Model,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
@@ -343,11 +395,37 @@ impl StripeBilling {
Ok(())
}
+ pub async fn bill_model_request_usage(
+ &self,
+ customer_id: &stripe::CustomerId,
+ event_name: &str,
+ requests: i32,
+ ) -> Result<()> {
+ let timestamp = Utc::now().timestamp();
+ let idempotency_key = Uuid::new_v4();
+
+ StripeMeterEvent::create(
+ &self.client,
+ StripeCreateMeterEventParams {
+ identifier: &format!("model_requests/{}", idempotency_key),
+ event_name,
+ payload: StripeCreateMeterEventPayload {
+ value: requests as u64,
+ stripe_customer_id: customer_id,
+ },
+ timestamp: Some(timestamp),
+ },
+ )
+ .await?;
+
+ Ok(())
+ }
+
pub async fn checkout(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
- model: &StripeModel,
+ model: &StripeModelTokenPrices,
success_url: &str,
) -> Result<String> {
let first_of_next_month = Utc::now()