Prevent deadlock when create a new meter/price on Stripe (#19196)

Antonio Scandurra created

This also puts the entire state of `StripeBilling` behind a `RwLock`.
When fetching the existing prices and meters, or when inserting new
ones, we acquire a write lock and hold it until the Stripe request
completes. This prevents two concurrent calls to `get_or_insert_price`
from inserting the same data twice.

Creating a new meter/price is unusual, so in practice we'll acquire a
read lock most of the time.

/cc @rtfeldman @maxdeviant 

Release Notes:

- N/A

Change summary

crates/collab/src/stripe_billing.rs | 199 ++++++++++++++++--------------
1 file changed, 107 insertions(+), 92 deletions(-)

Detailed changes

crates/collab/src/stripe_billing.rs 🔗

@@ -8,11 +8,16 @@ use serde::{Deserialize, Serialize};
 use tokio::sync::RwLock;
 
 pub struct StripeBilling {
-    meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
-    price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
+    state: RwLock<StripeBillingState>,
     client: Arc<stripe::Client>,
 }
 
+#[derive(Default)]
+struct StripeBillingState {
+    meters_by_event_name: HashMap<String, StripeMeter>,
+    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+}
+
 pub struct StripeModel {
     input_tokens_price: StripeBillingPrice,
     input_cache_creation_tokens_price: StripeBillingPrice,
@@ -29,36 +34,36 @@ impl StripeBilling {
     pub fn new(client: Arc<stripe::Client>) -> Self {
         Self {
             client,
-            meters_by_event_name: RwLock::new(HashMap::default()),
-            price_ids_by_meter_id: RwLock::new(HashMap::default()),
+            state: RwLock::default(),
         }
     }
 
     pub async fn initialize(&self) -> Result<()> {
-        log::info!("initializing StripeBilling");
+        log::info!("StripeBilling: initializing");
 
-        {
-            let meters = StripeMeter::list(&self.client).await?.data;
-            let mut meters_by_event_name = self.meters_by_event_name.write().await;
-            for meter in meters {
-                meters_by_event_name.insert(meter.event_name.clone(), meter);
-            }
+        let mut state = self.state.write().await;
+
+        let (meters, prices) = futures::try_join!(
+            StripeMeter::list(&self.client),
+            stripe::Price::list(&self.client, &stripe::ListPrices::default())
+        )?;
+
+        for meter in meters.data {
+            state
+                .meters_by_event_name
+                .insert(meter.event_name.clone(), meter);
         }
 
-        {
-            let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
-                .await?
-                .data;
-            let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
-            for price in prices {
-                if let Some(recurring) = price.recurring {
-                    if let Some(meter) = recurring.meter {
-                        price_ids_by_meter_id.insert(meter, price.id);
-                    }
+        for price in prices.data {
+            if let Some(recurring) = price.recurring {
+                if let Some(meter) = recurring.meter {
+                    state.price_ids_by_meter_id.insert(meter, price.id);
                 }
             }
         }
 
+        log::info!("StripeBilling: initialized");
+
         Ok(())
     }
 
@@ -105,79 +110,89 @@ impl StripeBilling {
         price_description: &str,
         price_per_million_tokens: Cents,
     ) -> Result<StripeBillingPrice> {
-        let meter =
-            if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
-                meter.clone()
-            } else {
-                let meter = StripeMeter::create(
-                    &self.client,
-                    StripeCreateMeterParams {
-                        default_aggregation: DefaultAggregation { formula: "sum" },
-                        display_name: price_description.to_string(),
-                        event_name: meter_event_name,
-                    },
-                )
-                .await?;
-                self.meters_by_event_name
-                    .write()
-                    .await
-                    .insert(meter_event_name.to_string(), meter.clone());
-                meter
-            };
-
-        let price_id =
-            if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
-                price_id.clone()
-            } else {
-                let price = stripe::Price::create(
-                    &self.client,
-                    stripe::CreatePrice {
+        // Fast code path when the meter and the price already exist.
+        {
+            let state = self.state.read().await;
+            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
+                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
+                    return Ok(StripeBillingPrice {
+                        id: price_id.clone(),
+                        meter_event_name: meter_event_name.to_string(),
+                    });
+                }
+            }
+        }
+
+        let mut state = self.state.write().await;
+        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
+            meter.clone()
+        } else {
+            let meter = StripeMeter::create(
+                &self.client,
+                StripeCreateMeterParams {
+                    default_aggregation: DefaultAggregation { formula: "sum" },
+                    display_name: price_description.to_string(),
+                    event_name: meter_event_name,
+                },
+            )
+            .await?;
+            state
+                .meters_by_event_name
+                .insert(meter_event_name.to_string(), meter.clone());
+            meter
+        };
+
+        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
+            price_id.clone()
+        } else {
+            let price = stripe::Price::create(
+                &self.client,
+                stripe::CreatePrice {
+                    active: Some(true),
+                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
+                    currency: stripe::Currency::USD,
+                    currency_options: None,
+                    custom_unit_amount: None,
+                    expand: &[],
+                    lookup_key: None,
+                    metadata: None,
+                    nickname: None,
+                    product: None,
+                    product_data: Some(stripe::CreatePriceProductData {
+                        id: None,
                         active: Some(true),
-                        billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
-                        currency: stripe::Currency::USD,
-                        currency_options: None,
-                        custom_unit_amount: None,
-                        expand: &[],
-                        lookup_key: None,
                         metadata: None,
-                        nickname: None,
-                        product: None,
-                        product_data: Some(stripe::CreatePriceProductData {
-                            id: None,
-                            active: Some(true),
-                            metadata: None,
-                            name: price_description.to_string(),
-                            statement_descriptor: None,
-                            tax_code: None,
-                            unit_label: None,
-                        }),
-                        recurring: Some(stripe::CreatePriceRecurring {
-                            aggregate_usage: None,
-                            interval: stripe::CreatePriceRecurringInterval::Month,
-                            interval_count: None,
-                            trial_period_days: None,
-                            usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
-                            meter: Some(meter.id.clone()),
-                        }),
-                        tax_behavior: None,
-                        tiers: None,
-                        tiers_mode: None,
-                        transfer_lookup_key: None,
-                        transform_quantity: None,
-                        unit_amount: None,
-                        unit_amount_decimal: Some(&format!(
-                            "{:.12}",
-                            price_per_million_tokens.0 as f64 / 1_000_000f64
-                        )),
-                    },
-                )
-                .await?;
-                self.price_ids_by_meter_id
-                    .write()
-                    .await
-                    .insert(meter.id, price.id.clone());
-                price.id
-            };
+                        name: price_description.to_string(),
+                        statement_descriptor: None,
+                        tax_code: None,
+                        unit_label: None,
+                    }),
+                    recurring: Some(stripe::CreatePriceRecurring {
+                        aggregate_usage: None,
+                        interval: stripe::CreatePriceRecurringInterval::Month,
+                        interval_count: None,
+                        trial_period_days: None,
+                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
+                        meter: Some(meter.id.clone()),
+                    }),
+                    tax_behavior: None,
+                    tiers: None,
+                    tiers_mode: None,
+                    transfer_lookup_key: None,
+                    transform_quantity: None,
+                    unit_amount: None,
+                    unit_amount_decimal: Some(&format!(
+                        "{:.12}",
+                        price_per_million_tokens.0 as f64 / 1_000_000f64
+                    )),
+                },
+            )
+            .await?;
+            state
+                .price_ids_by_meter_id
+                .insert(meter.id, price.id.clone());
+            price.id
+        };
 
         Ok(StripeBillingPrice {
             id: price_id,