@@ -8,11 +8,16 @@ use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
pub struct StripeBilling {
- meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
- price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
+ state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>,
}
+#[derive(Default)]
+struct StripeBillingState {
+ meters_by_event_name: HashMap<String, StripeMeter>,
+ price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+}
+
pub struct StripeModel {
input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice,
@@ -29,36 +34,36 @@ impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client,
- meters_by_event_name: RwLock::new(HashMap::default()),
- price_ids_by_meter_id: RwLock::new(HashMap::default()),
+ state: RwLock::default(),
}
}
pub async fn initialize(&self) -> Result<()> {
- log::info!("initializing StripeBilling");
+ log::info!("StripeBilling: initializing");
- {
- let meters = StripeMeter::list(&self.client).await?.data;
- let mut meters_by_event_name = self.meters_by_event_name.write().await;
- for meter in meters {
- meters_by_event_name.insert(meter.event_name.clone(), meter);
- }
+ let mut state = self.state.write().await;
+
+ let (meters, prices) = futures::try_join!(
+ StripeMeter::list(&self.client),
+ stripe::Price::list(&self.client, &stripe::ListPrices::default())
+ )?;
+
+ for meter in meters.data {
+ state
+ .meters_by_event_name
+ .insert(meter.event_name.clone(), meter);
}
- {
- let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
- .await?
- .data;
- let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
- for price in prices {
- if let Some(recurring) = price.recurring {
- if let Some(meter) = recurring.meter {
- price_ids_by_meter_id.insert(meter, price.id);
- }
+ for price in prices.data {
+ if let Some(recurring) = price.recurring {
+ if let Some(meter) = recurring.meter {
+ state.price_ids_by_meter_id.insert(meter, price.id);
}
}
}
+ log::info!("StripeBilling: initialized");
+
Ok(())
}
@@ -105,79 +110,89 @@ impl StripeBilling {
price_description: &str,
price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> {
- let meter =
- if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
- meter.clone()
- } else {
- let meter = StripeMeter::create(
- &self.client,
- StripeCreateMeterParams {
- default_aggregation: DefaultAggregation { formula: "sum" },
- display_name: price_description.to_string(),
- event_name: meter_event_name,
- },
- )
- .await?;
- self.meters_by_event_name
- .write()
- .await
- .insert(meter_event_name.to_string(), meter.clone());
- meter
- };
-
- let price_id =
- if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
- price_id.clone()
- } else {
- let price = stripe::Price::create(
- &self.client,
- stripe::CreatePrice {
+ // Fast code path when the meter and the price already exist.
+ {
+ let state = self.state.read().await;
+ if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
+ if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
+ return Ok(StripeBillingPrice {
+ id: price_id.clone(),
+ meter_event_name: meter_event_name.to_string(),
+ });
+ }
+ }
+ }
+
+ let mut state = self.state.write().await;
+ let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
+ meter.clone()
+ } else {
+ let meter = StripeMeter::create(
+ &self.client,
+ StripeCreateMeterParams {
+ default_aggregation: DefaultAggregation { formula: "sum" },
+ display_name: price_description.to_string(),
+ event_name: meter_event_name,
+ },
+ )
+ .await?;
+ state
+ .meters_by_event_name
+ .insert(meter_event_name.to_string(), meter.clone());
+ meter
+ };
+
+ let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
+ price_id.clone()
+ } else {
+ let price = stripe::Price::create(
+ &self.client,
+ stripe::CreatePrice {
+ active: Some(true),
+ billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
+ currency: stripe::Currency::USD,
+ currency_options: None,
+ custom_unit_amount: None,
+ expand: &[],
+ lookup_key: None,
+ metadata: None,
+ nickname: None,
+ product: None,
+ product_data: Some(stripe::CreatePriceProductData {
+ id: None,
active: Some(true),
- billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
- currency: stripe::Currency::USD,
- currency_options: None,
- custom_unit_amount: None,
- expand: &[],
- lookup_key: None,
metadata: None,
- nickname: None,
- product: None,
- product_data: Some(stripe::CreatePriceProductData {
- id: None,
- active: Some(true),
- metadata: None,
- name: price_description.to_string(),
- statement_descriptor: None,
- tax_code: None,
- unit_label: None,
- }),
- recurring: Some(stripe::CreatePriceRecurring {
- aggregate_usage: None,
- interval: stripe::CreatePriceRecurringInterval::Month,
- interval_count: None,
- trial_period_days: None,
- usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
- meter: Some(meter.id.clone()),
- }),
- tax_behavior: None,
- tiers: None,
- tiers_mode: None,
- transfer_lookup_key: None,
- transform_quantity: None,
- unit_amount: None,
- unit_amount_decimal: Some(&format!(
- "{:.12}",
- price_per_million_tokens.0 as f64 / 1_000_000f64
- )),
- },
- )
- .await?;
- self.price_ids_by_meter_id
- .write()
- .await
- .insert(meter.id, price.id.clone());
- price.id
- };
+ name: price_description.to_string(),
+ statement_descriptor: None,
+ tax_code: None,
+ unit_label: None,
+ }),
+ recurring: Some(stripe::CreatePriceRecurring {
+ aggregate_usage: None,
+ interval: stripe::CreatePriceRecurringInterval::Month,
+ interval_count: None,
+ trial_period_days: None,
+ usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
+ meter: Some(meter.id.clone()),
+ }),
+ tax_behavior: None,
+ tiers: None,
+ tiers_mode: None,
+ transfer_lookup_key: None,
+ transform_quantity: None,
+ unit_amount: None,
+ unit_amount_decimal: Some(&format!(
+ "{:.12}",
+ price_per_million_tokens.0 as f64 / 1_000_000f64
+ )),
+ },
+ )
+ .await?;
+ state
+ .price_ids_by_meter_id
+ .insert(meter.id, price.id.clone());
+ price.id
+ };
Ok(StripeBillingPrice {
id: price_id,