stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::{Cents, Result, llm};
  4use anyhow::Context as _;
  5use chrono::{Datelike, Utc};
  6use collections::HashMap;
  7use serde::{Deserialize, Serialize};
  8use stripe::PriceId;
  9use tokio::sync::RwLock;
 10
 11pub struct StripeBilling {
 12    state: RwLock<StripeBillingState>,
 13    client: Arc<stripe::Client>,
 14}
 15
 16#[derive(Default)]
 17struct StripeBillingState {
 18    meters_by_event_name: HashMap<String, StripeMeter>,
 19    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 20}
 21
 22pub struct StripeModel {
 23    input_tokens_price: StripeBillingPrice,
 24    input_cache_creation_tokens_price: StripeBillingPrice,
 25    input_cache_read_tokens_price: StripeBillingPrice,
 26    output_tokens_price: StripeBillingPrice,
 27}
 28
 29struct StripeBillingPrice {
 30    id: stripe::PriceId,
 31    meter_event_name: String,
 32}
 33
 34impl StripeBilling {
 35    pub fn new(client: Arc<stripe::Client>) -> Self {
 36        Self {
 37            client,
 38            state: RwLock::default(),
 39        }
 40    }
 41
 42    pub async fn initialize(&self) -> Result<()> {
 43        log::info!("StripeBilling: initializing");
 44
 45        let mut state = self.state.write().await;
 46
 47        let (meters, prices) = futures::try_join!(
 48            StripeMeter::list(&self.client),
 49            stripe::Price::list(
 50                &self.client,
 51                &stripe::ListPrices {
 52                    limit: Some(100),
 53                    ..Default::default()
 54                }
 55            )
 56        )?;
 57
 58        for meter in meters.data {
 59            state
 60                .meters_by_event_name
 61                .insert(meter.event_name.clone(), meter);
 62        }
 63
 64        for price in prices.data {
 65            if let Some(recurring) = price.recurring {
 66                if let Some(meter) = recurring.meter {
 67                    state.price_ids_by_meter_id.insert(meter, price.id);
 68                }
 69            }
 70        }
 71
 72        log::info!("StripeBilling: initialized");
 73
 74        Ok(())
 75    }
 76
 77    pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
 78        let input_tokens_price = self
 79            .get_or_insert_price(
 80                &format!("model_{}/input_tokens", model.id),
 81                &format!("{} (Input Tokens)", model.name),
 82                Cents::new(model.price_per_million_input_tokens as u32),
 83            )
 84            .await?;
 85        let input_cache_creation_tokens_price = self
 86            .get_or_insert_price(
 87                &format!("model_{}/input_cache_creation_tokens", model.id),
 88                &format!("{} (Input Cache Creation Tokens)", model.name),
 89                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
 90            )
 91            .await?;
 92        let input_cache_read_tokens_price = self
 93            .get_or_insert_price(
 94                &format!("model_{}/input_cache_read_tokens", model.id),
 95                &format!("{} (Input Cache Read Tokens)", model.name),
 96                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
 97            )
 98            .await?;
 99        let output_tokens_price = self
100            .get_or_insert_price(
101                &format!("model_{}/output_tokens", model.id),
102                &format!("{} (Output Tokens)", model.name),
103                Cents::new(model.price_per_million_output_tokens as u32),
104            )
105            .await?;
106        Ok(StripeModel {
107            input_tokens_price,
108            input_cache_creation_tokens_price,
109            input_cache_read_tokens_price,
110            output_tokens_price,
111        })
112    }
113
114    async fn get_or_insert_price(
115        &self,
116        meter_event_name: &str,
117        price_description: &str,
118        price_per_million_tokens: Cents,
119    ) -> Result<StripeBillingPrice> {
120        // Fast code path when the meter and the price already exist.
121        {
122            let state = self.state.read().await;
123            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
124                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
125                    return Ok(StripeBillingPrice {
126                        id: price_id.clone(),
127                        meter_event_name: meter_event_name.to_string(),
128                    });
129                }
130            }
131        }
132
133        let mut state = self.state.write().await;
134        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
135            meter.clone()
136        } else {
137            let meter = StripeMeter::create(
138                &self.client,
139                StripeCreateMeterParams {
140                    default_aggregation: DefaultAggregation { formula: "sum" },
141                    display_name: price_description.to_string(),
142                    event_name: meter_event_name,
143                },
144            )
145            .await?;
146            state
147                .meters_by_event_name
148                .insert(meter_event_name.to_string(), meter.clone());
149            meter
150        };
151
152        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
153            price_id.clone()
154        } else {
155            let price = stripe::Price::create(
156                &self.client,
157                stripe::CreatePrice {
158                    active: Some(true),
159                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
160                    currency: stripe::Currency::USD,
161                    currency_options: None,
162                    custom_unit_amount: None,
163                    expand: &[],
164                    lookup_key: None,
165                    metadata: None,
166                    nickname: None,
167                    product: None,
168                    product_data: Some(stripe::CreatePriceProductData {
169                        id: None,
170                        active: Some(true),
171                        metadata: None,
172                        name: price_description.to_string(),
173                        statement_descriptor: None,
174                        tax_code: None,
175                        unit_label: None,
176                    }),
177                    recurring: Some(stripe::CreatePriceRecurring {
178                        aggregate_usage: None,
179                        interval: stripe::CreatePriceRecurringInterval::Month,
180                        interval_count: None,
181                        trial_period_days: None,
182                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
183                        meter: Some(meter.id.clone()),
184                    }),
185                    tax_behavior: None,
186                    tiers: None,
187                    tiers_mode: None,
188                    transfer_lookup_key: None,
189                    transform_quantity: None,
190                    unit_amount: None,
191                    unit_amount_decimal: Some(&format!(
192                        "{:.12}",
193                        price_per_million_tokens.0 as f64 / 1_000_000f64
194                    )),
195                },
196            )
197            .await?;
198            state
199                .price_ids_by_meter_id
200                .insert(meter.id, price.id.clone());
201            price.id
202        };
203
204        Ok(StripeBillingPrice {
205            id: price_id,
206            meter_event_name: meter_event_name.to_string(),
207        })
208    }
209
210    pub async fn subscribe_to_model(
211        &self,
212        subscription_id: &stripe::SubscriptionId,
213        model: &StripeModel,
214    ) -> Result<()> {
215        let subscription =
216            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
217
218        let mut items = Vec::new();
219
220        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
221            items.push(stripe::UpdateSubscriptionItems {
222                price: Some(model.input_tokens_price.id.to_string()),
223                ..Default::default()
224            });
225        }
226
227        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
228        {
229            items.push(stripe::UpdateSubscriptionItems {
230                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
231                ..Default::default()
232            });
233        }
234
235        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
236            items.push(stripe::UpdateSubscriptionItems {
237                price: Some(model.input_cache_read_tokens_price.id.to_string()),
238                ..Default::default()
239            });
240        }
241
242        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
243            items.push(stripe::UpdateSubscriptionItems {
244                price: Some(model.output_tokens_price.id.to_string()),
245                ..Default::default()
246            });
247        }
248
249        if !items.is_empty() {
250            items.extend(subscription.items.data.iter().map(|item| {
251                stripe::UpdateSubscriptionItems {
252                    id: Some(item.id.to_string()),
253                    ..Default::default()
254                }
255            }));
256
257            stripe::Subscription::update(
258                &self.client,
259                subscription_id,
260                stripe::UpdateSubscription {
261                    items: Some(items),
262                    ..Default::default()
263                },
264            )
265            .await?;
266        }
267
268        Ok(())
269    }
270
271    pub async fn bill_model_token_usage(
272        &self,
273        customer_id: &stripe::CustomerId,
274        model: &StripeModel,
275        event: &llm::db::billing_event::Model,
276    ) -> Result<()> {
277        let timestamp = Utc::now().timestamp();
278
279        if event.input_tokens > 0 {
280            StripeMeterEvent::create(
281                &self.client,
282                StripeCreateMeterEventParams {
283                    identifier: &format!("input_tokens/{}", event.idempotency_key),
284                    event_name: &model.input_tokens_price.meter_event_name,
285                    payload: StripeCreateMeterEventPayload {
286                        value: event.input_tokens as u64,
287                        stripe_customer_id: customer_id,
288                    },
289                    timestamp: Some(timestamp),
290                },
291            )
292            .await?;
293        }
294
295        if event.input_cache_creation_tokens > 0 {
296            StripeMeterEvent::create(
297                &self.client,
298                StripeCreateMeterEventParams {
299                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
300                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
301                    payload: StripeCreateMeterEventPayload {
302                        value: event.input_cache_creation_tokens as u64,
303                        stripe_customer_id: customer_id,
304                    },
305                    timestamp: Some(timestamp),
306                },
307            )
308            .await?;
309        }
310
311        if event.input_cache_read_tokens > 0 {
312            StripeMeterEvent::create(
313                &self.client,
314                StripeCreateMeterEventParams {
315                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
316                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
317                    payload: StripeCreateMeterEventPayload {
318                        value: event.input_cache_read_tokens as u64,
319                        stripe_customer_id: customer_id,
320                    },
321                    timestamp: Some(timestamp),
322                },
323            )
324            .await?;
325        }
326
327        if event.output_tokens > 0 {
328            StripeMeterEvent::create(
329                &self.client,
330                StripeCreateMeterEventParams {
331                    identifier: &format!("output_tokens/{}", event.idempotency_key),
332                    event_name: &model.output_tokens_price.meter_event_name,
333                    payload: StripeCreateMeterEventPayload {
334                        value: event.output_tokens as u64,
335                        stripe_customer_id: customer_id,
336                    },
337                    timestamp: Some(timestamp),
338                },
339            )
340            .await?;
341        }
342
343        Ok(())
344    }
345
346    pub async fn checkout(
347        &self,
348        customer_id: stripe::CustomerId,
349        github_login: &str,
350        model: &StripeModel,
351        success_url: &str,
352    ) -> Result<String> {
353        let first_of_next_month = Utc::now()
354            .checked_add_months(chrono::Months::new(1))
355            .unwrap()
356            .with_day(1)
357            .unwrap();
358
359        let mut params = stripe::CreateCheckoutSession::new();
360        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
361        params.customer = Some(customer_id);
362        params.client_reference_id = Some(github_login);
363        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
364            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
365            ..Default::default()
366        });
367        params.line_items = Some(
368            [
369                &model.input_tokens_price.id,
370                &model.input_cache_creation_tokens_price.id,
371                &model.input_cache_read_tokens_price.id,
372                &model.output_tokens_price.id,
373            ]
374            .into_iter()
375            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
376                price: Some(price_id.to_string()),
377                ..Default::default()
378            })
379            .collect(),
380        );
381        params.success_url = Some(success_url);
382
383        let session = stripe::CheckoutSession::create(&self.client, params).await?;
384        Ok(session.url.context("no checkout session URL")?)
385    }
386
387    pub async fn checkout_with_price(
388        &self,
389        price_id: PriceId,
390        customer_id: stripe::CustomerId,
391        github_login: &str,
392        success_url: &str,
393    ) -> Result<String> {
394        let mut params = stripe::CreateCheckoutSession::new();
395        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
396        params.customer = Some(customer_id);
397        params.client_reference_id = Some(github_login);
398        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
399            price: Some(price_id.to_string()),
400            quantity: Some(1),
401            ..Default::default()
402        }]);
403        params.success_url = Some(success_url);
404
405        let session = stripe::CheckoutSession::create(&self.client, params).await?;
406        Ok(session.url.context("no checkout session URL")?)
407    }
408
409    pub async fn checkout_with_zed_pro_trial(
410        &self,
411        zed_pro_price_id: PriceId,
412        customer_id: stripe::CustomerId,
413        github_login: &str,
414        success_url: &str,
415    ) -> Result<String> {
416        let mut params = stripe::CreateCheckoutSession::new();
417        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
418            trial_period_days: Some(14),
419            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
420                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
421                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
422                }
423            }),
424            ..Default::default()
425        });
426        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
427        params.payment_method_collection =
428            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
429        params.customer = Some(customer_id);
430        params.client_reference_id = Some(github_login);
431        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
432            price: Some(zed_pro_price_id.to_string()),
433            quantity: Some(1),
434            ..Default::default()
435        }]);
436        params.success_url = Some(success_url);
437
438        let session = stripe::CheckoutSession::create(&self.client, params).await?;
439        Ok(session.url.context("no checkout session URL")?)
440    }
441}
442
443#[derive(Serialize)]
444struct DefaultAggregation {
445    formula: &'static str,
446}
447
448#[derive(Serialize)]
449struct StripeCreateMeterParams<'a> {
450    default_aggregation: DefaultAggregation,
451    display_name: String,
452    event_name: &'a str,
453}
454
455#[derive(Clone, Deserialize)]
456struct StripeMeter {
457    id: String,
458    event_name: String,
459}
460
461impl StripeMeter {
462    pub fn create(
463        client: &stripe::Client,
464        params: StripeCreateMeterParams,
465    ) -> stripe::Response<Self> {
466        client.post_form("/billing/meters", params)
467    }
468
469    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
470        #[derive(Serialize)]
471        struct Params {
472            #[serde(skip_serializing_if = "Option::is_none")]
473            limit: Option<u64>,
474        }
475
476        client.get_query("/billing/meters", Params { limit: Some(100) })
477    }
478}
479
480#[derive(Deserialize)]
481struct StripeMeterEvent {
482    identifier: String,
483}
484
485impl StripeMeterEvent {
486    pub async fn create(
487        client: &stripe::Client,
488        params: StripeCreateMeterEventParams<'_>,
489    ) -> Result<Self, stripe::StripeError> {
490        let identifier = params.identifier;
491        match client.post_form("/billing/meter_events", params).await {
492            Ok(event) => Ok(event),
493            Err(stripe::StripeError::Stripe(error)) => {
494                if error.http_status == 400
495                    && error
496                        .message
497                        .as_ref()
498                        .map_or(false, |message| message.contains(identifier))
499                {
500                    Ok(Self {
501                        identifier: identifier.to_string(),
502                    })
503                } else {
504                    Err(stripe::StripeError::Stripe(error))
505                }
506            }
507            Err(error) => Err(error),
508        }
509    }
510}
511
512#[derive(Serialize)]
513struct StripeCreateMeterEventParams<'a> {
514    identifier: &'a str,
515    event_name: &'a str,
516    payload: StripeCreateMeterEventPayload<'a>,
517    timestamp: Option<i64>,
518}
519
520#[derive(Serialize)]
521struct StripeCreateMeterEventPayload<'a> {
522    value: u64,
523    stripe_customer_id: &'a stripe::CustomerId,
524}
525
526fn subscription_contains_price(
527    subscription: &stripe::Subscription,
528    price_id: &stripe::PriceId,
529) -> bool {
530    subscription.items.data.iter().any(|item| {
531        item.price
532            .as_ref()
533            .map_or(false, |price| price.id == *price_id)
534    })
535}