stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::{llm, Cents, Result};
  4use anyhow::Context;
  5use chrono::Utc;
  6use collections::HashMap;
  7use serde::{Deserialize, Serialize};
  8use tokio::sync::RwLock;
  9
/// Stripe billing front-end that pairs an API client with an in-memory cache
/// of the billing meters and metered prices it has seen or created.
pub struct StripeBilling {
    /// Cached meters and price ids, guarded by an async read/write lock.
    /// Populated by `initialize` and extended by `get_or_insert_price`.
    state: RwLock<StripeBillingState>,
    /// Shared Stripe API client used for every request made by this type.
    client: Arc<stripe::Client>,
}
 14
/// Lookup tables describing which meters and prices are known to exist.
/// Starts out empty (`Default`).
#[derive(Default)]
struct StripeBillingState {
    /// Meters keyed by their `event_name`.
    meters_by_event_name: HashMap<String, StripeMeter>,
    /// Price ids keyed by the id of the meter the price is attached to.
    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
}
 20
/// The four metered prices registered for one language model, one per kind of
/// token that is billed. Produced by `StripeBilling::register_model`.
pub struct StripeModel {
    /// Price for regular input tokens.
    input_tokens_price: StripeBillingPrice,
    /// Price for input tokens that create a cache entry.
    input_cache_creation_tokens_price: StripeBillingPrice,
    /// Price for input tokens that are read from the cache.
    input_cache_read_tokens_price: StripeBillingPrice,
    /// Price for output tokens.
    output_tokens_price: StripeBillingPrice,
}
 27
/// A Stripe price together with the event name of the meter that feeds it.
struct StripeBillingPrice {
    /// Id of the price in Stripe.
    id: stripe::PriceId,
    /// Event name used when reporting usage for this price.
    meter_event_name: String,
}
 32
 33impl StripeBilling {
 34    pub fn new(client: Arc<stripe::Client>) -> Self {
 35        Self {
 36            client,
 37            state: RwLock::default(),
 38        }
 39    }
 40
 41    pub async fn initialize(&self) -> Result<()> {
 42        log::info!("StripeBilling: initializing");
 43
 44        let mut state = self.state.write().await;
 45
 46        let (meters, prices) = futures::try_join!(
 47            StripeMeter::list(&self.client),
 48            stripe::Price::list(
 49                &self.client,
 50                &stripe::ListPrices {
 51                    limit: Some(100),
 52                    ..Default::default()
 53                }
 54            )
 55        )?;
 56
 57        for meter in meters.data {
 58            state
 59                .meters_by_event_name
 60                .insert(meter.event_name.clone(), meter);
 61        }
 62
 63        for price in prices.data {
 64            if let Some(recurring) = price.recurring {
 65                if let Some(meter) = recurring.meter {
 66                    state.price_ids_by_meter_id.insert(meter, price.id);
 67                }
 68            }
 69        }
 70
 71        log::info!("StripeBilling: initialized");
 72
 73        Ok(())
 74    }
 75
 76    pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
 77        let input_tokens_price = self
 78            .get_or_insert_price(
 79                &format!("model_{}/input_tokens", model.id),
 80                &format!("{} (Input Tokens)", model.name),
 81                Cents::new(model.price_per_million_input_tokens as u32),
 82            )
 83            .await?;
 84        let input_cache_creation_tokens_price = self
 85            .get_or_insert_price(
 86                &format!("model_{}/input_cache_creation_tokens", model.id),
 87                &format!("{} (Input Cache Creation Tokens)", model.name),
 88                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
 89            )
 90            .await?;
 91        let input_cache_read_tokens_price = self
 92            .get_or_insert_price(
 93                &format!("model_{}/input_cache_read_tokens", model.id),
 94                &format!("{} (Input Cache Read Tokens)", model.name),
 95                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
 96            )
 97            .await?;
 98        let output_tokens_price = self
 99            .get_or_insert_price(
100                &format!("model_{}/output_tokens", model.id),
101                &format!("{} (Output Tokens)", model.name),
102                Cents::new(model.price_per_million_output_tokens as u32),
103            )
104            .await?;
105        Ok(StripeModel {
106            input_tokens_price,
107            input_cache_creation_tokens_price,
108            input_cache_read_tokens_price,
109            output_tokens_price,
110        })
111    }
112
113    async fn get_or_insert_price(
114        &self,
115        meter_event_name: &str,
116        price_description: &str,
117        price_per_million_tokens: Cents,
118    ) -> Result<StripeBillingPrice> {
119        // Fast code path when the meter and the price already exist.
120        {
121            let state = self.state.read().await;
122            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
123                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
124                    return Ok(StripeBillingPrice {
125                        id: price_id.clone(),
126                        meter_event_name: meter_event_name.to_string(),
127                    });
128                }
129            }
130        }
131
132        let mut state = self.state.write().await;
133        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
134            meter.clone()
135        } else {
136            let meter = StripeMeter::create(
137                &self.client,
138                StripeCreateMeterParams {
139                    default_aggregation: DefaultAggregation { formula: "sum" },
140                    display_name: price_description.to_string(),
141                    event_name: meter_event_name,
142                },
143            )
144            .await?;
145            state
146                .meters_by_event_name
147                .insert(meter_event_name.to_string(), meter.clone());
148            meter
149        };
150
151        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
152            price_id.clone()
153        } else {
154            let price = stripe::Price::create(
155                &self.client,
156                stripe::CreatePrice {
157                    active: Some(true),
158                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
159                    currency: stripe::Currency::USD,
160                    currency_options: None,
161                    custom_unit_amount: None,
162                    expand: &[],
163                    lookup_key: None,
164                    metadata: None,
165                    nickname: None,
166                    product: None,
167                    product_data: Some(stripe::CreatePriceProductData {
168                        id: None,
169                        active: Some(true),
170                        metadata: None,
171                        name: price_description.to_string(),
172                        statement_descriptor: None,
173                        tax_code: None,
174                        unit_label: None,
175                    }),
176                    recurring: Some(stripe::CreatePriceRecurring {
177                        aggregate_usage: None,
178                        interval: stripe::CreatePriceRecurringInterval::Month,
179                        interval_count: None,
180                        trial_period_days: None,
181                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
182                        meter: Some(meter.id.clone()),
183                    }),
184                    tax_behavior: None,
185                    tiers: None,
186                    tiers_mode: None,
187                    transfer_lookup_key: None,
188                    transform_quantity: None,
189                    unit_amount: None,
190                    unit_amount_decimal: Some(&format!(
191                        "{:.12}",
192                        price_per_million_tokens.0 as f64 / 1_000_000f64
193                    )),
194                },
195            )
196            .await?;
197            state
198                .price_ids_by_meter_id
199                .insert(meter.id, price.id.clone());
200            price.id
201        };
202
203        Ok(StripeBillingPrice {
204            id: price_id,
205            meter_event_name: meter_event_name.to_string(),
206        })
207    }
208
209    pub async fn subscribe_to_model(
210        &self,
211        subscription_id: &stripe::SubscriptionId,
212        model: &StripeModel,
213    ) -> Result<()> {
214        let subscription =
215            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
216
217        let mut items = Vec::new();
218
219        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
220            items.push(stripe::UpdateSubscriptionItems {
221                price: Some(model.input_tokens_price.id.to_string()),
222                ..Default::default()
223            });
224        }
225
226        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
227        {
228            items.push(stripe::UpdateSubscriptionItems {
229                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
230                ..Default::default()
231            });
232        }
233
234        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
235            items.push(stripe::UpdateSubscriptionItems {
236                price: Some(model.input_cache_read_tokens_price.id.to_string()),
237                ..Default::default()
238            });
239        }
240
241        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
242            items.push(stripe::UpdateSubscriptionItems {
243                price: Some(model.output_tokens_price.id.to_string()),
244                ..Default::default()
245            });
246        }
247
248        if !items.is_empty() {
249            items.extend(subscription.items.data.iter().map(|item| {
250                stripe::UpdateSubscriptionItems {
251                    id: Some(item.id.to_string()),
252                    ..Default::default()
253                }
254            }));
255
256            stripe::Subscription::update(
257                &self.client,
258                subscription_id,
259                stripe::UpdateSubscription {
260                    items: Some(items),
261                    ..Default::default()
262                },
263            )
264            .await?;
265        }
266
267        Ok(())
268    }
269
270    pub async fn bill_model_usage(
271        &self,
272        customer_id: &stripe::CustomerId,
273        model: &StripeModel,
274        event: &llm::db::billing_event::Model,
275    ) -> Result<()> {
276        let timestamp = Utc::now().timestamp();
277
278        if event.input_tokens > 0 {
279            StripeMeterEvent::create(
280                &self.client,
281                StripeCreateMeterEventParams {
282                    identifier: &format!("input_tokens/{}", event.idempotency_key),
283                    event_name: &model.input_tokens_price.meter_event_name,
284                    payload: StripeCreateMeterEventPayload {
285                        value: event.input_tokens as u64,
286                        stripe_customer_id: customer_id,
287                    },
288                    timestamp: Some(timestamp),
289                },
290            )
291            .await?;
292        }
293
294        if event.input_cache_creation_tokens > 0 {
295            StripeMeterEvent::create(
296                &self.client,
297                StripeCreateMeterEventParams {
298                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
299                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
300                    payload: StripeCreateMeterEventPayload {
301                        value: event.input_cache_creation_tokens as u64,
302                        stripe_customer_id: customer_id,
303                    },
304                    timestamp: Some(timestamp),
305                },
306            )
307            .await?;
308        }
309
310        if event.input_cache_read_tokens > 0 {
311            StripeMeterEvent::create(
312                &self.client,
313                StripeCreateMeterEventParams {
314                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
315                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
316                    payload: StripeCreateMeterEventPayload {
317                        value: event.input_cache_read_tokens as u64,
318                        stripe_customer_id: customer_id,
319                    },
320                    timestamp: Some(timestamp),
321                },
322            )
323            .await?;
324        }
325
326        if event.output_tokens > 0 {
327            StripeMeterEvent::create(
328                &self.client,
329                StripeCreateMeterEventParams {
330                    identifier: &format!("output_tokens/{}", event.idempotency_key),
331                    event_name: &model.output_tokens_price.meter_event_name,
332                    payload: StripeCreateMeterEventPayload {
333                        value: event.output_tokens as u64,
334                        stripe_customer_id: customer_id,
335                    },
336                    timestamp: Some(timestamp),
337                },
338            )
339            .await?;
340        }
341
342        Ok(())
343    }
344
345    pub async fn checkout(
346        &self,
347        customer_id: stripe::CustomerId,
348        github_login: &str,
349        model: &StripeModel,
350        success_url: &str,
351    ) -> Result<String> {
352        let mut params = stripe::CreateCheckoutSession::new();
353        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
354        params.customer = Some(customer_id);
355        params.client_reference_id = Some(github_login);
356        params.line_items = Some(
357            [
358                &model.input_tokens_price.id,
359                &model.input_cache_creation_tokens_price.id,
360                &model.input_cache_read_tokens_price.id,
361                &model.output_tokens_price.id,
362            ]
363            .into_iter()
364            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
365                price: Some(price_id.to_string()),
366                ..Default::default()
367            })
368            .collect(),
369        );
370        params.success_url = Some(success_url);
371
372        let session = stripe::CheckoutSession::create(&self.client, params).await?;
373        Ok(session.url.context("no checkout session URL")?)
374    }
375}
376
/// Serialized `default_aggregation` parameter of a meter creation request.
#[derive(Serialize)]
struct DefaultAggregation {
    /// Aggregation formula name; this file only ever sends `"sum"`.
    formula: &'static str,
}
381
/// Form parameters posted to `/billing/meters` by `StripeMeter::create`.
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
    /// How reported values are aggregated.
    default_aggregation: DefaultAggregation,
    /// Human-readable name of the meter.
    display_name: String,
    /// Event name that usage reports must carry to be counted by this meter.
    event_name: &'a str,
}
388
/// The subset of a Stripe billing meter that this module reads from API
/// responses.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// Id of the meter; used as the key for looking up its price.
    id: String,
    /// Event name the meter listens for.
    event_name: String,
}
394
395impl StripeMeter {
396    pub fn create(
397        client: &stripe::Client,
398        params: StripeCreateMeterParams,
399    ) -> stripe::Response<Self> {
400        client.post_form("/billing/meters", params)
401    }
402
403    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
404        #[derive(Serialize)]
405        struct Params {
406            #[serde(skip_serializing_if = "Option::is_none")]
407            limit: Option<u64>,
408        }
409
410        client.get_query("/billing/meters", Params { limit: Some(100) })
411    }
412}
413
/// The subset of a Stripe meter event that this module reads from API
/// responses.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Caller-chosen identifier of the event.
    identifier: String,
}
418
419impl StripeMeterEvent {
420    pub async fn create(
421        client: &stripe::Client,
422        params: StripeCreateMeterEventParams<'_>,
423    ) -> Result<Self, stripe::StripeError> {
424        let identifier = params.identifier;
425        match client.post_form("/billing/meter_events", params).await {
426            Ok(event) => Ok(event),
427            Err(stripe::StripeError::Stripe(error)) => {
428                if error.http_status == 400
429                    && error
430                        .message
431                        .as_ref()
432                        .map_or(false, |message| message.contains(identifier))
433                {
434                    Ok(Self {
435                        identifier: identifier.to_string(),
436                    })
437                } else {
438                    Err(stripe::StripeError::Stripe(error))
439                }
440            }
441            Err(error) => Err(error),
442        }
443    }
444}
445
/// Form parameters posted to `/billing/meter_events` by
/// `StripeMeterEvent::create`.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Caller-chosen identifier of the event; `bill_model_usage` builds it
    /// from the token kind and the billing event's idempotency key.
    identifier: &'a str,
    /// Event name of the meter that should count this usage.
    event_name: &'a str,
    /// The usage being reported.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Time of the usage; `bill_model_usage` passes `Utc::now().timestamp()`.
    timestamp: Option<i64>,
}
453
/// Payload of a meter event: how much was used and by which customer.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Amount of usage; `bill_model_usage` sends a token count.
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
459
460fn subscription_contains_price(
461    subscription: &stripe::Subscription,
462    price_id: &stripe::PriceId,
463) -> bool {
464    subscription.items.data.iter().any(|item| {
465        item.price
466            .as_ref()
467            .map_or(false, |price| price.id == *price_id)
468    })
469}