stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::{Cents, Result, llm};
  4use anyhow::Context as _;
  5use chrono::{Datelike, Utc};
  6use collections::HashMap;
  7use serde::{Deserialize, Serialize};
  8use tokio::sync::RwLock;
  9
 10pub struct StripeBilling {
 11    state: RwLock<StripeBillingState>,
 12    client: Arc<stripe::Client>,
 13}
 14
 15#[derive(Default)]
 16struct StripeBillingState {
 17    meters_by_event_name: HashMap<String, StripeMeter>,
 18    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 19}
 20
 21pub struct StripeModel {
 22    input_tokens_price: StripeBillingPrice,
 23    input_cache_creation_tokens_price: StripeBillingPrice,
 24    input_cache_read_tokens_price: StripeBillingPrice,
 25    output_tokens_price: StripeBillingPrice,
 26}
 27
 28struct StripeBillingPrice {
 29    id: stripe::PriceId,
 30    meter_event_name: String,
 31}
 32
 33impl StripeBilling {
 34    pub fn new(client: Arc<stripe::Client>) -> Self {
 35        Self {
 36            client,
 37            state: RwLock::default(),
 38        }
 39    }
 40
 41    pub async fn initialize(&self) -> Result<()> {
 42        log::info!("StripeBilling: initializing");
 43
 44        let mut state = self.state.write().await;
 45
 46        let (meters, prices) = futures::try_join!(
 47            StripeMeter::list(&self.client),
 48            stripe::Price::list(
 49                &self.client,
 50                &stripe::ListPrices {
 51                    limit: Some(100),
 52                    ..Default::default()
 53                }
 54            )
 55        )?;
 56
 57        for meter in meters.data {
 58            state
 59                .meters_by_event_name
 60                .insert(meter.event_name.clone(), meter);
 61        }
 62
 63        for price in prices.data {
 64            if let Some(recurring) = price.recurring {
 65                if let Some(meter) = recurring.meter {
 66                    state.price_ids_by_meter_id.insert(meter, price.id);
 67                }
 68            }
 69        }
 70
 71        log::info!("StripeBilling: initialized");
 72
 73        Ok(())
 74    }
 75
 76    pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
 77        let input_tokens_price = self
 78            .get_or_insert_price(
 79                &format!("model_{}/input_tokens", model.id),
 80                &format!("{} (Input Tokens)", model.name),
 81                Cents::new(model.price_per_million_input_tokens as u32),
 82            )
 83            .await?;
 84        let input_cache_creation_tokens_price = self
 85            .get_or_insert_price(
 86                &format!("model_{}/input_cache_creation_tokens", model.id),
 87                &format!("{} (Input Cache Creation Tokens)", model.name),
 88                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
 89            )
 90            .await?;
 91        let input_cache_read_tokens_price = self
 92            .get_or_insert_price(
 93                &format!("model_{}/input_cache_read_tokens", model.id),
 94                &format!("{} (Input Cache Read Tokens)", model.name),
 95                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
 96            )
 97            .await?;
 98        let output_tokens_price = self
 99            .get_or_insert_price(
100                &format!("model_{}/output_tokens", model.id),
101                &format!("{} (Output Tokens)", model.name),
102                Cents::new(model.price_per_million_output_tokens as u32),
103            )
104            .await?;
105        Ok(StripeModel {
106            input_tokens_price,
107            input_cache_creation_tokens_price,
108            input_cache_read_tokens_price,
109            output_tokens_price,
110        })
111    }
112
113    async fn get_or_insert_price(
114        &self,
115        meter_event_name: &str,
116        price_description: &str,
117        price_per_million_tokens: Cents,
118    ) -> Result<StripeBillingPrice> {
119        // Fast code path when the meter and the price already exist.
120        {
121            let state = self.state.read().await;
122            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
123                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
124                    return Ok(StripeBillingPrice {
125                        id: price_id.clone(),
126                        meter_event_name: meter_event_name.to_string(),
127                    });
128                }
129            }
130        }
131
132        let mut state = self.state.write().await;
133        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
134            meter.clone()
135        } else {
136            let meter = StripeMeter::create(
137                &self.client,
138                StripeCreateMeterParams {
139                    default_aggregation: DefaultAggregation { formula: "sum" },
140                    display_name: price_description.to_string(),
141                    event_name: meter_event_name,
142                },
143            )
144            .await?;
145            state
146                .meters_by_event_name
147                .insert(meter_event_name.to_string(), meter.clone());
148            meter
149        };
150
151        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
152            price_id.clone()
153        } else {
154            let price = stripe::Price::create(
155                &self.client,
156                stripe::CreatePrice {
157                    active: Some(true),
158                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
159                    currency: stripe::Currency::USD,
160                    currency_options: None,
161                    custom_unit_amount: None,
162                    expand: &[],
163                    lookup_key: None,
164                    metadata: None,
165                    nickname: None,
166                    product: None,
167                    product_data: Some(stripe::CreatePriceProductData {
168                        id: None,
169                        active: Some(true),
170                        metadata: None,
171                        name: price_description.to_string(),
172                        statement_descriptor: None,
173                        tax_code: None,
174                        unit_label: None,
175                    }),
176                    recurring: Some(stripe::CreatePriceRecurring {
177                        aggregate_usage: None,
178                        interval: stripe::CreatePriceRecurringInterval::Month,
179                        interval_count: None,
180                        trial_period_days: None,
181                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
182                        meter: Some(meter.id.clone()),
183                    }),
184                    tax_behavior: None,
185                    tiers: None,
186                    tiers_mode: None,
187                    transfer_lookup_key: None,
188                    transform_quantity: None,
189                    unit_amount: None,
190                    unit_amount_decimal: Some(&format!(
191                        "{:.12}",
192                        price_per_million_tokens.0 as f64 / 1_000_000f64
193                    )),
194                },
195            )
196            .await?;
197            state
198                .price_ids_by_meter_id
199                .insert(meter.id, price.id.clone());
200            price.id
201        };
202
203        Ok(StripeBillingPrice {
204            id: price_id,
205            meter_event_name: meter_event_name.to_string(),
206        })
207    }
208
209    pub async fn subscribe_to_model(
210        &self,
211        subscription_id: &stripe::SubscriptionId,
212        model: &StripeModel,
213    ) -> Result<()> {
214        let subscription =
215            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
216
217        let mut items = Vec::new();
218
219        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
220            items.push(stripe::UpdateSubscriptionItems {
221                price: Some(model.input_tokens_price.id.to_string()),
222                ..Default::default()
223            });
224        }
225
226        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
227        {
228            items.push(stripe::UpdateSubscriptionItems {
229                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
230                ..Default::default()
231            });
232        }
233
234        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
235            items.push(stripe::UpdateSubscriptionItems {
236                price: Some(model.input_cache_read_tokens_price.id.to_string()),
237                ..Default::default()
238            });
239        }
240
241        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
242            items.push(stripe::UpdateSubscriptionItems {
243                price: Some(model.output_tokens_price.id.to_string()),
244                ..Default::default()
245            });
246        }
247
248        if !items.is_empty() {
249            items.extend(subscription.items.data.iter().map(|item| {
250                stripe::UpdateSubscriptionItems {
251                    id: Some(item.id.to_string()),
252                    ..Default::default()
253                }
254            }));
255
256            stripe::Subscription::update(
257                &self.client,
258                subscription_id,
259                stripe::UpdateSubscription {
260                    items: Some(items),
261                    ..Default::default()
262                },
263            )
264            .await?;
265        }
266
267        Ok(())
268    }
269
270    pub async fn bill_model_usage(
271        &self,
272        customer_id: &stripe::CustomerId,
273        model: &StripeModel,
274        event: &llm::db::billing_event::Model,
275    ) -> Result<()> {
276        let timestamp = Utc::now().timestamp();
277
278        if event.input_tokens > 0 {
279            StripeMeterEvent::create(
280                &self.client,
281                StripeCreateMeterEventParams {
282                    identifier: &format!("input_tokens/{}", event.idempotency_key),
283                    event_name: &model.input_tokens_price.meter_event_name,
284                    payload: StripeCreateMeterEventPayload {
285                        value: event.input_tokens as u64,
286                        stripe_customer_id: customer_id,
287                    },
288                    timestamp: Some(timestamp),
289                },
290            )
291            .await?;
292        }
293
294        if event.input_cache_creation_tokens > 0 {
295            StripeMeterEvent::create(
296                &self.client,
297                StripeCreateMeterEventParams {
298                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
299                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
300                    payload: StripeCreateMeterEventPayload {
301                        value: event.input_cache_creation_tokens as u64,
302                        stripe_customer_id: customer_id,
303                    },
304                    timestamp: Some(timestamp),
305                },
306            )
307            .await?;
308        }
309
310        if event.input_cache_read_tokens > 0 {
311            StripeMeterEvent::create(
312                &self.client,
313                StripeCreateMeterEventParams {
314                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
315                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
316                    payload: StripeCreateMeterEventPayload {
317                        value: event.input_cache_read_tokens as u64,
318                        stripe_customer_id: customer_id,
319                    },
320                    timestamp: Some(timestamp),
321                },
322            )
323            .await?;
324        }
325
326        if event.output_tokens > 0 {
327            StripeMeterEvent::create(
328                &self.client,
329                StripeCreateMeterEventParams {
330                    identifier: &format!("output_tokens/{}", event.idempotency_key),
331                    event_name: &model.output_tokens_price.meter_event_name,
332                    payload: StripeCreateMeterEventPayload {
333                        value: event.output_tokens as u64,
334                        stripe_customer_id: customer_id,
335                    },
336                    timestamp: Some(timestamp),
337                },
338            )
339            .await?;
340        }
341
342        Ok(())
343    }
344
345    pub async fn checkout(
346        &self,
347        customer_id: stripe::CustomerId,
348        github_login: &str,
349        model: &StripeModel,
350        success_url: &str,
351    ) -> Result<String> {
352        let first_of_next_month = Utc::now()
353            .checked_add_months(chrono::Months::new(1))
354            .unwrap()
355            .with_day(1)
356            .unwrap();
357
358        let mut params = stripe::CreateCheckoutSession::new();
359        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
360        params.customer = Some(customer_id);
361        params.client_reference_id = Some(github_login);
362        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
363            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
364            ..Default::default()
365        });
366        params.line_items = Some(
367            [
368                &model.input_tokens_price.id,
369                &model.input_cache_creation_tokens_price.id,
370                &model.input_cache_read_tokens_price.id,
371                &model.output_tokens_price.id,
372            ]
373            .into_iter()
374            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
375                price: Some(price_id.to_string()),
376                ..Default::default()
377            })
378            .collect(),
379        );
380        params.success_url = Some(success_url);
381
382        let session = stripe::CheckoutSession::create(&self.client, params).await?;
383        Ok(session.url.context("no checkout session URL")?)
384    }
385}
386
387#[derive(Serialize)]
388struct DefaultAggregation {
389    formula: &'static str,
390}
391
392#[derive(Serialize)]
393struct StripeCreateMeterParams<'a> {
394    default_aggregation: DefaultAggregation,
395    display_name: String,
396    event_name: &'a str,
397}
398
399#[derive(Clone, Deserialize)]
400struct StripeMeter {
401    id: String,
402    event_name: String,
403}
404
405impl StripeMeter {
406    pub fn create(
407        client: &stripe::Client,
408        params: StripeCreateMeterParams,
409    ) -> stripe::Response<Self> {
410        client.post_form("/billing/meters", params)
411    }
412
413    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
414        #[derive(Serialize)]
415        struct Params {
416            #[serde(skip_serializing_if = "Option::is_none")]
417            limit: Option<u64>,
418        }
419
420        client.get_query("/billing/meters", Params { limit: Some(100) })
421    }
422}
423
424#[derive(Deserialize)]
425struct StripeMeterEvent {
426    identifier: String,
427}
428
429impl StripeMeterEvent {
430    pub async fn create(
431        client: &stripe::Client,
432        params: StripeCreateMeterEventParams<'_>,
433    ) -> Result<Self, stripe::StripeError> {
434        let identifier = params.identifier;
435        match client.post_form("/billing/meter_events", params).await {
436            Ok(event) => Ok(event),
437            Err(stripe::StripeError::Stripe(error)) => {
438                if error.http_status == 400
439                    && error
440                        .message
441                        .as_ref()
442                        .map_or(false, |message| message.contains(identifier))
443                {
444                    Ok(Self {
445                        identifier: identifier.to_string(),
446                    })
447                } else {
448                    Err(stripe::StripeError::Stripe(error))
449                }
450            }
451            Err(error) => Err(error),
452        }
453    }
454}
455
456#[derive(Serialize)]
457struct StripeCreateMeterEventParams<'a> {
458    identifier: &'a str,
459    event_name: &'a str,
460    payload: StripeCreateMeterEventPayload<'a>,
461    timestamp: Option<i64>,
462}
463
464#[derive(Serialize)]
465struct StripeCreateMeterEventPayload<'a> {
466    value: u64,
467    stripe_customer_id: &'a stripe::CustomerId,
468}
469
470fn subscription_contains_price(
471    subscription: &stripe::Subscription,
472    price_id: &stripe::PriceId,
473) -> bool {
474    subscription.items.data.iter().any(|item| {
475        item.price
476            .as_ref()
477            .map_or(false, |price| price.id == *price_id)
478    })
479}