@@ -210,6 +210,13 @@ async fn create_billing_subscription(
"not supported".into(),
))?
};
+ let Some(stripe_billing) = app.stripe_billing.clone() else {
+ log::error!("failed to retrieve Stripe billing object");
+ Err(Error::http(
+ StatusCode::NOT_IMPLEMENTED,
+ "not supported".into(),
+ ))?
+ };
let Some(llm_db) = app.llm_db.clone() else {
log::error!("failed to retrieve LLM database");
Err(Error::http(
@@ -236,7 +243,6 @@ async fn create_billing_subscription(
};
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
- let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
let checkout_session_url = stripe_billing
@@ -716,8 +722,8 @@ async fn find_or_create_billing_customer(
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
- let Some(stripe_client) = app.stripe_client.clone() else {
- log::warn!("failed to retrieve Stripe client");
+ let Some(stripe_billing) = app.stripe_billing.clone() else {
+ log::warn!("failed to retrieve Stripe billing object");
return;
};
let Some(llm_db) = app.llm_db.clone() else {
@@ -730,7 +736,7 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let executor = executor.clone();
async move {
loop {
- sync_with_stripe(&app, &llm_db, &stripe_client)
+ sync_with_stripe(&app, &llm_db, &stripe_billing)
.await
.trace_err();
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
@@ -742,10 +748,8 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
async fn sync_with_stripe(
app: &Arc<AppState>,
llm_db: &Arc<LlmDatabase>,
- stripe_client: &Arc<stripe::Client>,
+ stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
- let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
-
let events = llm_db.get_billing_events().await?;
let user_ids = events
.iter()
@@ -31,6 +31,8 @@ use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
+use crate::stripe_billing::StripeBilling;
+
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
@@ -274,6 +276,7 @@ pub struct AppState {
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>,
+ pub stripe_billing: Option<Arc<StripeBilling>>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub clickhouse_client: Option<::clickhouse::Client>,
@@ -317,12 +320,16 @@ impl AppState {
};
let db = Arc::new(db);
+ let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
let this = Self {
db: db.clone(),
llm_db,
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
- stripe_client: build_stripe_client(&config).map(Arc::new).log_err(),
+ stripe_billing: stripe_client
+ .clone()
+ .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
+ stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
clickhouse_client: config
@@ -5,10 +5,11 @@ use anyhow::Context;
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
+use tokio::sync::RwLock;
pub struct StripeBilling {
- meters_by_event_name: HashMap<String, StripeMeter>,
- price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+ meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
+ price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
client: Arc<stripe::Client>,
}
@@ -25,32 +26,43 @@ struct StripeBillingPrice {
}
impl StripeBilling {
- pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
- let mut meters_by_event_name = HashMap::default();
- for meter in StripeMeter::list(&client).await?.data {
- meters_by_event_name.insert(meter.event_name.clone(), meter);
+ pub fn new(client: Arc<stripe::Client>) -> Self {
+ Self {
+ client,
+ meters_by_event_name: RwLock::new(HashMap::default()),
+ price_ids_by_meter_id: RwLock::new(HashMap::default()),
}
+ }
+
+ pub async fn initialize(&self) -> Result<()> {
+ log::info!("initializing StripeBilling");
- let mut price_ids_by_meter_id = HashMap::default();
- for price in stripe::Price::list(&client, &stripe::ListPrices::default())
- .await?
- .data
{
- if let Some(recurring) = price.recurring {
- if let Some(meter) = recurring.meter {
- price_ids_by_meter_id.insert(meter, price.id);
+ let meters = StripeMeter::list(&self.client).await?.data;
+ let mut meters_by_event_name = self.meters_by_event_name.write().await;
+ for meter in meters {
+ meters_by_event_name.insert(meter.event_name.clone(), meter);
+ }
+ }
+
+ {
+ let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
+ .await?
+ .data;
+ let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
+ for price in prices {
+ if let Some(recurring) = price.recurring {
+ if let Some(meter) = recurring.meter {
+ price_ids_by_meter_id.insert(meter, price.id);
+ }
}
}
}
- Ok(Self {
- meters_by_event_name,
- price_ids_by_meter_id,
- client,
- })
+ Ok(())
}
- pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
+ pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
let input_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_tokens", model.id),
@@ -88,78 +100,84 @@ impl StripeBilling {
}
async fn get_or_insert_price(
- &mut self,
+ &self,
meter_event_name: &str,
price_description: &str,
price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> {
- let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
- meter.clone()
- } else {
- let meter = StripeMeter::create(
- &self.client,
- StripeCreateMeterParams {
- default_aggregation: DefaultAggregation { formula: "sum" },
- display_name: price_description.to_string(),
- event_name: meter_event_name,
- },
- )
- .await?;
- self.meters_by_event_name
- .insert(meter_event_name.to_string(), meter.clone());
- meter
- };
-
- let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
- price_id.clone()
- } else {
- let price = stripe::Price::create(
- &self.client,
- stripe::CreatePrice {
- active: Some(true),
- billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
- currency: stripe::Currency::USD,
- currency_options: None,
- custom_unit_amount: None,
- expand: &[],
- lookup_key: None,
- metadata: None,
- nickname: None,
- product: None,
- product_data: Some(stripe::CreatePriceProductData {
- id: None,
+ let meter =
+ if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
+ meter.clone()
+ } else {
+ let meter = StripeMeter::create(
+ &self.client,
+ StripeCreateMeterParams {
+ default_aggregation: DefaultAggregation { formula: "sum" },
+ display_name: price_description.to_string(),
+ event_name: meter_event_name,
+ },
+ )
+ .await?;
+ self.meters_by_event_name
+ .write()
+ .await
+ .insert(meter_event_name.to_string(), meter.clone());
+ meter
+ };
+
+ let price_id =
+ if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
+ price_id.clone()
+ } else {
+ let price = stripe::Price::create(
+ &self.client,
+ stripe::CreatePrice {
active: Some(true),
+ billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
+ currency: stripe::Currency::USD,
+ currency_options: None,
+ custom_unit_amount: None,
+ expand: &[],
+ lookup_key: None,
metadata: None,
- name: price_description.to_string(),
- statement_descriptor: None,
- tax_code: None,
- unit_label: None,
- }),
- recurring: Some(stripe::CreatePriceRecurring {
- aggregate_usage: None,
- interval: stripe::CreatePriceRecurringInterval::Month,
- interval_count: None,
- trial_period_days: None,
- usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
- meter: Some(meter.id.clone()),
- }),
- tax_behavior: None,
- tiers: None,
- tiers_mode: None,
- transfer_lookup_key: None,
- transform_quantity: None,
- unit_amount: None,
- unit_amount_decimal: Some(&format!(
- "{:.12}",
- price_per_million_tokens.0 as f64 / 1_000_000f64
- )),
- },
- )
- .await?;
- self.price_ids_by_meter_id
- .insert(meter.id, price.id.clone());
- price.id
- };
+ nickname: None,
+ product: None,
+ product_data: Some(stripe::CreatePriceProductData {
+ id: None,
+ active: Some(true),
+ metadata: None,
+ name: price_description.to_string(),
+ statement_descriptor: None,
+ tax_code: None,
+ unit_label: None,
+ }),
+ recurring: Some(stripe::CreatePriceRecurring {
+ aggregate_usage: None,
+ interval: stripe::CreatePriceRecurringInterval::Month,
+ interval_count: None,
+ trial_period_days: None,
+ usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
+ meter: Some(meter.id.clone()),
+ }),
+ tax_behavior: None,
+ tiers: None,
+ tiers_mode: None,
+ transfer_lookup_key: None,
+ transform_quantity: None,
+ unit_amount: None,
+ unit_amount_decimal: Some(&format!(
+ "{:.12}",
+ price_per_million_tokens.0 as f64 / 1_000_000f64
+ )),
+ },
+ )
+ .await?;
+ self.price_ids_by_meter_id
+ .write()
+ .await
+ .insert(meter.id, price.id.clone());
+ price.id
+ };
Ok(StripeBillingPrice {
id: price_id,