collab: Make the `StripeBilling` object long-lived (#19090)

Marshall Bowers and Richard created

This PR makes the `StripeBilling` object long-lived so that we can make
better use of the cached data on it.

We now hold it on the `AppState` and spawn a background task to
initialize the cache on startup.

Release Notes:

- N/A

Co-authored-by: Richard <richard@zed.dev>

Change summary

crates/collab/src/api/billing.rs       |  18 +-
crates/collab/src/lib.rs               |   9 +
crates/collab/src/main.rs              |   7 +
crates/collab/src/stripe_billing.rs    | 188 +++++++++++++++------------
crates/collab/src/tests/test_server.rs |   1 
5 files changed, 130 insertions(+), 93 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -210,6 +210,13 @@ async fn create_billing_subscription(
             "not supported".into(),
         ))?
     };
+    let Some(stripe_billing) = app.stripe_billing.clone() else {
+        log::error!("failed to retrieve Stripe billing object");
+        Err(Error::http(
+            StatusCode::NOT_IMPLEMENTED,
+            "not supported".into(),
+        ))?
+    };
     let Some(llm_db) = app.llm_db.clone() else {
         log::error!("failed to retrieve LLM database");
         Err(Error::http(
@@ -236,7 +243,6 @@ async fn create_billing_subscription(
         };
 
     let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
-    let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
     let stripe_model = stripe_billing.register_model(default_model).await?;
     let success_url = format!("{}/account", app.config.zed_dot_dev_url());
     let checkout_session_url = stripe_billing
@@ -716,8 +722,8 @@ async fn find_or_create_billing_customer(
 const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
 
 pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
-    let Some(stripe_client) = app.stripe_client.clone() else {
-        log::warn!("failed to retrieve Stripe client");
+    let Some(stripe_billing) = app.stripe_billing.clone() else {
+        log::warn!("failed to retrieve Stripe billing object");
         return;
     };
     let Some(llm_db) = app.llm_db.clone() else {
@@ -730,7 +736,7 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
         let executor = executor.clone();
         async move {
             loop {
-                sync_with_stripe(&app, &llm_db, &stripe_client)
+                sync_with_stripe(&app, &llm_db, &stripe_billing)
                     .await
                     .trace_err();
                 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
@@ -742,10 +748,8 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
 async fn sync_with_stripe(
     app: &Arc<AppState>,
     llm_db: &Arc<LlmDatabase>,
-    stripe_client: &Arc<stripe::Client>,
+    stripe_billing: &Arc<StripeBilling>,
 ) -> anyhow::Result<()> {
-    let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
-
     let events = llm_db.get_billing_events().await?;
     let user_ids = events
         .iter()

crates/collab/src/lib.rs 🔗

@@ -31,6 +31,8 @@ use serde::Deserialize;
 use std::{path::PathBuf, sync::Arc};
 use util::ResultExt;
 
+use crate::stripe_billing::StripeBilling;
+
 pub type Result<T, E = Error> = std::result::Result<T, E>;
 
 pub enum Error {
@@ -274,6 +276,7 @@ pub struct AppState {
     pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
     pub blob_store_client: Option<aws_sdk_s3::Client>,
     pub stripe_client: Option<Arc<stripe::Client>>,
+    pub stripe_billing: Option<Arc<StripeBilling>>,
     pub rate_limiter: Arc<RateLimiter>,
     pub executor: Executor,
     pub clickhouse_client: Option<::clickhouse::Client>,
@@ -317,12 +320,16 @@ impl AppState {
         };
 
         let db = Arc::new(db);
+        let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
         let this = Self {
             db: db.clone(),
             llm_db,
             live_kit_client,
             blob_store_client: build_blob_store_client(&config).await.log_err(),
-            stripe_client: build_stripe_client(&config).map(Arc::new).log_err(),
+            stripe_billing: stripe_client
+                .clone()
+                .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
+            stripe_client,
             rate_limiter: Arc::new(RateLimiter::new(db)),
             executor,
             clickhouse_client: config

crates/collab/src/main.rs 🔗

@@ -111,6 +111,13 @@ async fn main() -> Result<()> {
 
                 let state = AppState::new(config, Executor::Production).await?;
 
+                if let Some(stripe_billing) = state.stripe_billing.clone() {
+                    let executor = state.executor.clone();
+                    executor.spawn_detached(async move {
+                        stripe_billing.initialize().await.trace_err();
+                    });
+                }
+
                 if mode.is_collab() {
                     state.db.purge_old_embeddings().await.trace_err();
                     RateLimiter::save_periodically(

crates/collab/src/stripe_billing.rs 🔗

@@ -5,10 +5,11 @@ use anyhow::Context;
 use chrono::Utc;
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
+use tokio::sync::RwLock;
 
 pub struct StripeBilling {
-    meters_by_event_name: HashMap<String, StripeMeter>,
-    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+    meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
+    price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
     client: Arc<stripe::Client>,
 }
 
@@ -25,32 +26,43 @@ struct StripeBillingPrice {
 }
 
 impl StripeBilling {
-    pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
-        let mut meters_by_event_name = HashMap::default();
-        for meter in StripeMeter::list(&client).await?.data {
-            meters_by_event_name.insert(meter.event_name.clone(), meter);
+    pub fn new(client: Arc<stripe::Client>) -> Self {
+        Self {
+            client,
+            meters_by_event_name: RwLock::new(HashMap::default()),
+            price_ids_by_meter_id: RwLock::new(HashMap::default()),
         }
+    }
+
+    pub async fn initialize(&self) -> Result<()> {
+        log::info!("initializing StripeBilling");
 
-        let mut price_ids_by_meter_id = HashMap::default();
-        for price in stripe::Price::list(&client, &stripe::ListPrices::default())
-            .await?
-            .data
         {
-            if let Some(recurring) = price.recurring {
-                if let Some(meter) = recurring.meter {
-                    price_ids_by_meter_id.insert(meter, price.id);
+            let meters = StripeMeter::list(&self.client).await?.data;
+            let mut meters_by_event_name = self.meters_by_event_name.write().await;
+            for meter in meters {
+                meters_by_event_name.insert(meter.event_name.clone(), meter);
+            }
+        }
+
+        {
+            let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
+                .await?
+                .data;
+            let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
+            for price in prices {
+                if let Some(recurring) = price.recurring {
+                    if let Some(meter) = recurring.meter {
+                        price_ids_by_meter_id.insert(meter, price.id);
+                    }
                 }
             }
         }
 
-        Ok(Self {
-            meters_by_event_name,
-            price_ids_by_meter_id,
-            client,
-        })
+        Ok(())
     }
 
-    pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
+    pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
         let input_tokens_price = self
             .get_or_insert_price(
                 &format!("model_{}/input_tokens", model.id),
@@ -88,78 +100,84 @@ impl StripeBilling {
     }
 
     async fn get_or_insert_price(
-        &mut self,
+        &self,
         meter_event_name: &str,
         price_description: &str,
         price_per_million_tokens: Cents,
     ) -> Result<StripeBillingPrice> {
-        let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
-            meter.clone()
-        } else {
-            let meter = StripeMeter::create(
-                &self.client,
-                StripeCreateMeterParams {
-                    default_aggregation: DefaultAggregation { formula: "sum" },
-                    display_name: price_description.to_string(),
-                    event_name: meter_event_name,
-                },
-            )
-            .await?;
-            self.meters_by_event_name
-                .insert(meter_event_name.to_string(), meter.clone());
-            meter
-        };
-
-        let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
-            price_id.clone()
-        } else {
-            let price = stripe::Price::create(
-                &self.client,
-                stripe::CreatePrice {
-                    active: Some(true),
-                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
-                    currency: stripe::Currency::USD,
-                    currency_options: None,
-                    custom_unit_amount: None,
-                    expand: &[],
-                    lookup_key: None,
-                    metadata: None,
-                    nickname: None,
-                    product: None,
-                    product_data: Some(stripe::CreatePriceProductData {
-                        id: None,
+        let meter =
+            if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
+                meter.clone()
+            } else {
+                let meter = StripeMeter::create(
+                    &self.client,
+                    StripeCreateMeterParams {
+                        default_aggregation: DefaultAggregation { formula: "sum" },
+                        display_name: price_description.to_string(),
+                        event_name: meter_event_name,
+                    },
+                )
+                .await?;
+                self.meters_by_event_name
+                    .write()
+                    .await
+                    .insert(meter_event_name.to_string(), meter.clone());
+                meter
+            };
+
+        let price_id =
+            if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
+                price_id.clone()
+            } else {
+                let price = stripe::Price::create(
+                    &self.client,
+                    stripe::CreatePrice {
                         active: Some(true),
+                        billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
+                        currency: stripe::Currency::USD,
+                        currency_options: None,
+                        custom_unit_amount: None,
+                        expand: &[],
+                        lookup_key: None,
                         metadata: None,
-                        name: price_description.to_string(),
-                        statement_descriptor: None,
-                        tax_code: None,
-                        unit_label: None,
-                    }),
-                    recurring: Some(stripe::CreatePriceRecurring {
-                        aggregate_usage: None,
-                        interval: stripe::CreatePriceRecurringInterval::Month,
-                        interval_count: None,
-                        trial_period_days: None,
-                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
-                        meter: Some(meter.id.clone()),
-                    }),
-                    tax_behavior: None,
-                    tiers: None,
-                    tiers_mode: None,
-                    transfer_lookup_key: None,
-                    transform_quantity: None,
-                    unit_amount: None,
-                    unit_amount_decimal: Some(&format!(
-                        "{:.12}",
-                        price_per_million_tokens.0 as f64 / 1_000_000f64
-                    )),
-                },
-            )
-            .await?;
-            self.price_ids_by_meter_id
-                .insert(meter.id, price.id.clone());
-            price.id
-        };
+                        nickname: None,
+                        product: None,
+                        product_data: Some(stripe::CreatePriceProductData {
+                            id: None,
+                            active: Some(true),
+                            metadata: None,
+                            name: price_description.to_string(),
+                            statement_descriptor: None,
+                            tax_code: None,
+                            unit_label: None,
+                        }),
+                        recurring: Some(stripe::CreatePriceRecurring {
+                            aggregate_usage: None,
+                            interval: stripe::CreatePriceRecurringInterval::Month,
+                            interval_count: None,
+                            trial_period_days: None,
+                            usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
+                            meter: Some(meter.id.clone()),
+                        }),
+                        tax_behavior: None,
+                        tiers: None,
+                        tiers_mode: None,
+                        transfer_lookup_key: None,
+                        transform_quantity: None,
+                        unit_amount: None,
+                        unit_amount_decimal: Some(&format!(
+                            "{:.12}",
+                            price_per_million_tokens.0 as f64 / 1_000_000f64
+                        )),
+                    },
+                )
+                .await?;
+                self.price_ids_by_meter_id
+                    .write()
+                    .await
+                    .insert(meter.id, price.id.clone());
+                price.id
+            };
 
         Ok(StripeBillingPrice {
             id: price_id,

crates/collab/src/tests/test_server.rs 🔗

@@ -639,6 +639,7 @@ impl TestServer {
             live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
             blob_store_client: None,
             stripe_client: None,
+            stripe_billing: None,
             rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
             executor,
             clickhouse_client: None,