stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::{llm, Cents, Result};
  4use anyhow::Context;
  5use chrono::Utc;
  6use collections::HashMap;
  7use serde::{Deserialize, Serialize};
  8use tokio::sync::RwLock;
  9
 10pub struct StripeBilling {
 11    meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
 12    price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
 13    client: Arc<stripe::Client>,
 14}
 15
 16pub struct StripeModel {
 17    input_tokens_price: StripeBillingPrice,
 18    input_cache_creation_tokens_price: StripeBillingPrice,
 19    input_cache_read_tokens_price: StripeBillingPrice,
 20    output_tokens_price: StripeBillingPrice,
 21}
 22
 23struct StripeBillingPrice {
 24    id: stripe::PriceId,
 25    meter_event_name: String,
 26}
 27
 28impl StripeBilling {
 29    pub fn new(client: Arc<stripe::Client>) -> Self {
 30        Self {
 31            client,
 32            meters_by_event_name: RwLock::new(HashMap::default()),
 33            price_ids_by_meter_id: RwLock::new(HashMap::default()),
 34        }
 35    }
 36
 37    pub async fn initialize(&self) -> Result<()> {
 38        log::info!("initializing StripeBilling");
 39
 40        {
 41            let meters = StripeMeter::list(&self.client).await?.data;
 42            let mut meters_by_event_name = self.meters_by_event_name.write().await;
 43            for meter in meters {
 44                meters_by_event_name.insert(meter.event_name.clone(), meter);
 45            }
 46        }
 47
 48        {
 49            let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
 50                .await?
 51                .data;
 52            let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
 53            for price in prices {
 54                if let Some(recurring) = price.recurring {
 55                    if let Some(meter) = recurring.meter {
 56                        price_ids_by_meter_id.insert(meter, price.id);
 57                    }
 58                }
 59            }
 60        }
 61
 62        Ok(())
 63    }
 64
 65    pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
 66        let input_tokens_price = self
 67            .get_or_insert_price(
 68                &format!("model_{}/input_tokens", model.id),
 69                &format!("{} (Input Tokens)", model.name),
 70                Cents::new(model.price_per_million_input_tokens as u32),
 71            )
 72            .await?;
 73        let input_cache_creation_tokens_price = self
 74            .get_or_insert_price(
 75                &format!("model_{}/input_cache_creation_tokens", model.id),
 76                &format!("{} (Input Cache Creation Tokens)", model.name),
 77                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
 78            )
 79            .await?;
 80        let input_cache_read_tokens_price = self
 81            .get_or_insert_price(
 82                &format!("model_{}/input_cache_read_tokens", model.id),
 83                &format!("{} (Input Cache Read Tokens)", model.name),
 84                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
 85            )
 86            .await?;
 87        let output_tokens_price = self
 88            .get_or_insert_price(
 89                &format!("model_{}/output_tokens", model.id),
 90                &format!("{} (Output Tokens)", model.name),
 91                Cents::new(model.price_per_million_output_tokens as u32),
 92            )
 93            .await?;
 94        Ok(StripeModel {
 95            input_tokens_price,
 96            input_cache_creation_tokens_price,
 97            input_cache_read_tokens_price,
 98            output_tokens_price,
 99        })
100    }
101
102    async fn get_or_insert_price(
103        &self,
104        meter_event_name: &str,
105        price_description: &str,
106        price_per_million_tokens: Cents,
107    ) -> Result<StripeBillingPrice> {
108        let meter =
109            if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
110                meter.clone()
111            } else {
112                let meter = StripeMeter::create(
113                    &self.client,
114                    StripeCreateMeterParams {
115                        default_aggregation: DefaultAggregation { formula: "sum" },
116                        display_name: price_description.to_string(),
117                        event_name: meter_event_name,
118                    },
119                )
120                .await?;
121                self.meters_by_event_name
122                    .write()
123                    .await
124                    .insert(meter_event_name.to_string(), meter.clone());
125                meter
126            };
127
128        let price_id =
129            if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
130                price_id.clone()
131            } else {
132                let price = stripe::Price::create(
133                    &self.client,
134                    stripe::CreatePrice {
135                        active: Some(true),
136                        billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
137                        currency: stripe::Currency::USD,
138                        currency_options: None,
139                        custom_unit_amount: None,
140                        expand: &[],
141                        lookup_key: None,
142                        metadata: None,
143                        nickname: None,
144                        product: None,
145                        product_data: Some(stripe::CreatePriceProductData {
146                            id: None,
147                            active: Some(true),
148                            metadata: None,
149                            name: price_description.to_string(),
150                            statement_descriptor: None,
151                            tax_code: None,
152                            unit_label: None,
153                        }),
154                        recurring: Some(stripe::CreatePriceRecurring {
155                            aggregate_usage: None,
156                            interval: stripe::CreatePriceRecurringInterval::Month,
157                            interval_count: None,
158                            trial_period_days: None,
159                            usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
160                            meter: Some(meter.id.clone()),
161                        }),
162                        tax_behavior: None,
163                        tiers: None,
164                        tiers_mode: None,
165                        transfer_lookup_key: None,
166                        transform_quantity: None,
167                        unit_amount: None,
168                        unit_amount_decimal: Some(&format!(
169                            "{:.12}",
170                            price_per_million_tokens.0 as f64 / 1_000_000f64
171                        )),
172                    },
173                )
174                .await?;
175                self.price_ids_by_meter_id
176                    .write()
177                    .await
178                    .insert(meter.id, price.id.clone());
179                price.id
180            };
181
182        Ok(StripeBillingPrice {
183            id: price_id,
184            meter_event_name: meter_event_name.to_string(),
185        })
186    }
187
188    pub async fn subscribe_to_model(
189        &self,
190        subscription_id: &stripe::SubscriptionId,
191        model: &StripeModel,
192    ) -> Result<()> {
193        let subscription =
194            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
195
196        let mut items = Vec::new();
197
198        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
199            items.push(stripe::UpdateSubscriptionItems {
200                price: Some(model.input_tokens_price.id.to_string()),
201                ..Default::default()
202            });
203        }
204
205        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
206        {
207            items.push(stripe::UpdateSubscriptionItems {
208                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
209                ..Default::default()
210            });
211        }
212
213        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
214            items.push(stripe::UpdateSubscriptionItems {
215                price: Some(model.input_cache_read_tokens_price.id.to_string()),
216                ..Default::default()
217            });
218        }
219
220        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
221            items.push(stripe::UpdateSubscriptionItems {
222                price: Some(model.output_tokens_price.id.to_string()),
223                ..Default::default()
224            });
225        }
226
227        if !items.is_empty() {
228            items.extend(subscription.items.data.iter().map(|item| {
229                stripe::UpdateSubscriptionItems {
230                    id: Some(item.id.to_string()),
231                    ..Default::default()
232                }
233            }));
234
235            stripe::Subscription::update(
236                &self.client,
237                subscription_id,
238                stripe::UpdateSubscription {
239                    items: Some(items),
240                    ..Default::default()
241                },
242            )
243            .await?;
244        }
245
246        Ok(())
247    }
248
249    pub async fn bill_model_usage(
250        &self,
251        customer_id: &stripe::CustomerId,
252        model: &StripeModel,
253        event: &llm::db::billing_event::Model,
254    ) -> Result<()> {
255        let timestamp = Utc::now().timestamp();
256
257        if event.input_tokens > 0 {
258            StripeMeterEvent::create(
259                &self.client,
260                StripeCreateMeterEventParams {
261                    identifier: &format!("input_tokens/{}", event.idempotency_key),
262                    event_name: &model.input_tokens_price.meter_event_name,
263                    payload: StripeCreateMeterEventPayload {
264                        value: event.input_tokens as u64,
265                        stripe_customer_id: customer_id,
266                    },
267                    timestamp: Some(timestamp),
268                },
269            )
270            .await?;
271        }
272
273        if event.input_cache_creation_tokens > 0 {
274            StripeMeterEvent::create(
275                &self.client,
276                StripeCreateMeterEventParams {
277                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
278                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
279                    payload: StripeCreateMeterEventPayload {
280                        value: event.input_cache_creation_tokens as u64,
281                        stripe_customer_id: customer_id,
282                    },
283                    timestamp: Some(timestamp),
284                },
285            )
286            .await?;
287        }
288
289        if event.input_cache_read_tokens > 0 {
290            StripeMeterEvent::create(
291                &self.client,
292                StripeCreateMeterEventParams {
293                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
294                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
295                    payload: StripeCreateMeterEventPayload {
296                        value: event.input_cache_read_tokens as u64,
297                        stripe_customer_id: customer_id,
298                    },
299                    timestamp: Some(timestamp),
300                },
301            )
302            .await?;
303        }
304
305        if event.output_tokens > 0 {
306            StripeMeterEvent::create(
307                &self.client,
308                StripeCreateMeterEventParams {
309                    identifier: &format!("output_tokens/{}", event.idempotency_key),
310                    event_name: &model.output_tokens_price.meter_event_name,
311                    payload: StripeCreateMeterEventPayload {
312                        value: event.output_tokens as u64,
313                        stripe_customer_id: customer_id,
314                    },
315                    timestamp: Some(timestamp),
316                },
317            )
318            .await?;
319        }
320
321        Ok(())
322    }
323
324    pub async fn checkout(
325        &self,
326        customer_id: stripe::CustomerId,
327        github_login: &str,
328        model: &StripeModel,
329        success_url: &str,
330    ) -> Result<String> {
331        let mut params = stripe::CreateCheckoutSession::new();
332        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
333        params.customer = Some(customer_id);
334        params.client_reference_id = Some(github_login);
335        params.line_items = Some(
336            [
337                &model.input_tokens_price.id,
338                &model.input_cache_creation_tokens_price.id,
339                &model.input_cache_read_tokens_price.id,
340                &model.output_tokens_price.id,
341            ]
342            .into_iter()
343            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
344                price: Some(price_id.to_string()),
345                ..Default::default()
346            })
347            .collect(),
348        );
349        params.success_url = Some(success_url);
350
351        let session = stripe::CheckoutSession::create(&self.client, params).await?;
352        Ok(session.url.context("no checkout session URL")?)
353    }
354}
355
356#[derive(Serialize)]
357struct DefaultAggregation {
358    formula: &'static str,
359}
360
361#[derive(Serialize)]
362struct StripeCreateMeterParams<'a> {
363    default_aggregation: DefaultAggregation,
364    display_name: String,
365    event_name: &'a str,
366}
367
368#[derive(Clone, Deserialize)]
369struct StripeMeter {
370    id: String,
371    event_name: String,
372}
373
374impl StripeMeter {
375    pub fn create(
376        client: &stripe::Client,
377        params: StripeCreateMeterParams,
378    ) -> stripe::Response<Self> {
379        client.post_form("/billing/meters", params)
380    }
381
382    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
383        #[derive(Serialize)]
384        struct Params {}
385
386        client.get_query("/billing/meters", Params {})
387    }
388}
389
390#[derive(Deserialize)]
391struct StripeMeterEvent {
392    identifier: String,
393}
394
395impl StripeMeterEvent {
396    pub async fn create(
397        client: &stripe::Client,
398        params: StripeCreateMeterEventParams<'_>,
399    ) -> Result<Self, stripe::StripeError> {
400        let identifier = params.identifier;
401        match client.post_form("/billing/meter_events", params).await {
402            Ok(event) => Ok(event),
403            Err(stripe::StripeError::Stripe(error)) => {
404                if error.http_status == 400
405                    && error
406                        .message
407                        .as_ref()
408                        .map_or(false, |message| message.contains(identifier))
409                {
410                    Ok(Self {
411                        identifier: identifier.to_string(),
412                    })
413                } else {
414                    Err(stripe::StripeError::Stripe(error))
415                }
416            }
417            Err(error) => Err(error),
418        }
419    }
420}
421
422#[derive(Serialize)]
423struct StripeCreateMeterEventParams<'a> {
424    identifier: &'a str,
425    event_name: &'a str,
426    payload: StripeCreateMeterEventPayload<'a>,
427    timestamp: Option<i64>,
428}
429
430#[derive(Serialize)]
431struct StripeCreateMeterEventPayload<'a> {
432    value: u64,
433    stripe_customer_id: &'a stripe::CustomerId,
434}
435
436fn subscription_contains_price(
437    subscription: &stripe::Subscription,
438    price_id: &stripe::PriceId,
439) -> bool {
440    subscription.items.data.iter().any(|item| {
441        item.price
442            .as_ref()
443            .map_or(false, |price| price.id == *price_id)
444    })
445}