stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::{Cents, Result, llm};
  4use anyhow::{Context as _, anyhow};
  5use chrono::{Datelike, Utc};
  6use collections::HashMap;
  7use serde::{Deserialize, Serialize};
  8use stripe::PriceId;
  9use tokio::sync::RwLock;
 10use uuid::Uuid;
 11
 12pub struct StripeBilling {
 13    state: RwLock<StripeBillingState>,
 14    client: Arc<stripe::Client>,
 15}
 16
 17#[derive(Default)]
 18struct StripeBillingState {
 19    meters_by_event_name: HashMap<String, StripeMeter>,
 20    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 21    prices_by_lookup_key: HashMap<String, stripe::Price>,
 22}
 23
 24pub struct StripeModelTokenPrices {
 25    input_tokens_price: StripeBillingPrice,
 26    input_cache_creation_tokens_price: StripeBillingPrice,
 27    input_cache_read_tokens_price: StripeBillingPrice,
 28    output_tokens_price: StripeBillingPrice,
 29}
 30
 31struct StripeBillingPrice {
 32    id: stripe::PriceId,
 33    meter_event_name: String,
 34}
 35
 36impl StripeBilling {
 37    pub fn new(client: Arc<stripe::Client>) -> Self {
 38        Self {
 39            client,
 40            state: RwLock::default(),
 41        }
 42    }
 43
 44    pub async fn initialize(&self) -> Result<()> {
 45        log::info!("StripeBilling: initializing");
 46
 47        let mut state = self.state.write().await;
 48
 49        let (meters, prices) = futures::try_join!(
 50            StripeMeter::list(&self.client),
 51            stripe::Price::list(
 52                &self.client,
 53                &stripe::ListPrices {
 54                    limit: Some(100),
 55                    ..Default::default()
 56                }
 57            )
 58        )?;
 59
 60        for meter in meters.data {
 61            state
 62                .meters_by_event_name
 63                .insert(meter.event_name.clone(), meter);
 64        }
 65
 66        for price in prices.data {
 67            if let Some(lookup_key) = price.lookup_key.clone() {
 68                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 69            }
 70
 71            if let Some(recurring) = price.recurring {
 72                if let Some(meter) = recurring.meter {
 73                    state.price_ids_by_meter_id.insert(meter, price.id);
 74                }
 75            }
 76        }
 77
 78        log::info!("StripeBilling: initialized");
 79
 80        Ok(())
 81    }
 82
 83    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
 84        self.state
 85            .read()
 86            .await
 87            .prices_by_lookup_key
 88            .get(lookup_key)
 89            .cloned()
 90            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 91    }
 92
 93    pub async fn register_model_for_token_based_usage(
 94        &self,
 95        model: &llm::db::model::Model,
 96    ) -> Result<StripeModelTokenPrices> {
 97        let input_tokens_price = self
 98            .get_or_insert_token_price(
 99                &format!("model_{}/input_tokens", model.id),
100                &format!("{} (Input Tokens)", model.name),
101                Cents::new(model.price_per_million_input_tokens as u32),
102            )
103            .await?;
104        let input_cache_creation_tokens_price = self
105            .get_or_insert_token_price(
106                &format!("model_{}/input_cache_creation_tokens", model.id),
107                &format!("{} (Input Cache Creation Tokens)", model.name),
108                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
109            )
110            .await?;
111        let input_cache_read_tokens_price = self
112            .get_or_insert_token_price(
113                &format!("model_{}/input_cache_read_tokens", model.id),
114                &format!("{} (Input Cache Read Tokens)", model.name),
115                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
116            )
117            .await?;
118        let output_tokens_price = self
119            .get_or_insert_token_price(
120                &format!("model_{}/output_tokens", model.id),
121                &format!("{} (Output Tokens)", model.name),
122                Cents::new(model.price_per_million_output_tokens as u32),
123            )
124            .await?;
125        Ok(StripeModelTokenPrices {
126            input_tokens_price,
127            input_cache_creation_tokens_price,
128            input_cache_read_tokens_price,
129            output_tokens_price,
130        })
131    }
132
133    async fn get_or_insert_token_price(
134        &self,
135        meter_event_name: &str,
136        price_description: &str,
137        price_per_million_tokens: Cents,
138    ) -> Result<StripeBillingPrice> {
139        // Fast code path when the meter and the price already exist.
140        {
141            let state = self.state.read().await;
142            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
143                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
144                    return Ok(StripeBillingPrice {
145                        id: price_id.clone(),
146                        meter_event_name: meter_event_name.to_string(),
147                    });
148                }
149            }
150        }
151
152        let mut state = self.state.write().await;
153        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
154            meter.clone()
155        } else {
156            let meter = StripeMeter::create(
157                &self.client,
158                StripeCreateMeterParams {
159                    default_aggregation: DefaultAggregation { formula: "sum" },
160                    display_name: price_description.to_string(),
161                    event_name: meter_event_name,
162                },
163            )
164            .await?;
165            state
166                .meters_by_event_name
167                .insert(meter_event_name.to_string(), meter.clone());
168            meter
169        };
170
171        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
172            price_id.clone()
173        } else {
174            let price = stripe::Price::create(
175                &self.client,
176                stripe::CreatePrice {
177                    active: Some(true),
178                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
179                    currency: stripe::Currency::USD,
180                    currency_options: None,
181                    custom_unit_amount: None,
182                    expand: &[],
183                    lookup_key: None,
184                    metadata: None,
185                    nickname: None,
186                    product: None,
187                    product_data: Some(stripe::CreatePriceProductData {
188                        id: None,
189                        active: Some(true),
190                        metadata: None,
191                        name: price_description.to_string(),
192                        statement_descriptor: None,
193                        tax_code: None,
194                        unit_label: None,
195                    }),
196                    recurring: Some(stripe::CreatePriceRecurring {
197                        aggregate_usage: None,
198                        interval: stripe::CreatePriceRecurringInterval::Month,
199                        interval_count: None,
200                        trial_period_days: None,
201                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
202                        meter: Some(meter.id.clone()),
203                    }),
204                    tax_behavior: None,
205                    tiers: None,
206                    tiers_mode: None,
207                    transfer_lookup_key: None,
208                    transform_quantity: None,
209                    unit_amount: None,
210                    unit_amount_decimal: Some(&format!(
211                        "{:.12}",
212                        price_per_million_tokens.0 as f64 / 1_000_000f64
213                    )),
214                },
215            )
216            .await?;
217            state
218                .price_ids_by_meter_id
219                .insert(meter.id, price.id.clone());
220            price.id
221        };
222
223        Ok(StripeBillingPrice {
224            id: price_id,
225            meter_event_name: meter_event_name.to_string(),
226        })
227    }
228
229    pub async fn subscribe_to_price(
230        &self,
231        subscription_id: &stripe::SubscriptionId,
232        price_id: &stripe::PriceId,
233    ) -> Result<()> {
234        let subscription =
235            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
236
237        if subscription_contains_price(&subscription, price_id) {
238            return Ok(());
239        }
240
241        stripe::Subscription::update(
242            &self.client,
243            subscription_id,
244            stripe::UpdateSubscription {
245                items: Some(vec![stripe::UpdateSubscriptionItems {
246                    price: Some(price_id.to_string()),
247                    ..Default::default()
248                }]),
249                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
250                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
251                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
252                    },
253                }),
254                ..Default::default()
255            },
256        )
257        .await?;
258
259        Ok(())
260    }
261
262    pub async fn subscribe_to_model(
263        &self,
264        subscription_id: &stripe::SubscriptionId,
265        model: &StripeModelTokenPrices,
266    ) -> Result<()> {
267        let subscription =
268            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
269
270        let mut items = Vec::new();
271
272        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
273            items.push(stripe::UpdateSubscriptionItems {
274                price: Some(model.input_tokens_price.id.to_string()),
275                ..Default::default()
276            });
277        }
278
279        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
280        {
281            items.push(stripe::UpdateSubscriptionItems {
282                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
283                ..Default::default()
284            });
285        }
286
287        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
288            items.push(stripe::UpdateSubscriptionItems {
289                price: Some(model.input_cache_read_tokens_price.id.to_string()),
290                ..Default::default()
291            });
292        }
293
294        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
295            items.push(stripe::UpdateSubscriptionItems {
296                price: Some(model.output_tokens_price.id.to_string()),
297                ..Default::default()
298            });
299        }
300
301        if !items.is_empty() {
302            items.extend(subscription.items.data.iter().map(|item| {
303                stripe::UpdateSubscriptionItems {
304                    id: Some(item.id.to_string()),
305                    ..Default::default()
306                }
307            }));
308
309            stripe::Subscription::update(
310                &self.client,
311                subscription_id,
312                stripe::UpdateSubscription {
313                    items: Some(items),
314                    ..Default::default()
315                },
316            )
317            .await?;
318        }
319
320        Ok(())
321    }
322
323    pub async fn bill_model_token_usage(
324        &self,
325        customer_id: &stripe::CustomerId,
326        model: &StripeModelTokenPrices,
327        event: &llm::db::billing_event::Model,
328    ) -> Result<()> {
329        let timestamp = Utc::now().timestamp();
330
331        if event.input_tokens > 0 {
332            StripeMeterEvent::create(
333                &self.client,
334                StripeCreateMeterEventParams {
335                    identifier: &format!("input_tokens/{}", event.idempotency_key),
336                    event_name: &model.input_tokens_price.meter_event_name,
337                    payload: StripeCreateMeterEventPayload {
338                        value: event.input_tokens as u64,
339                        stripe_customer_id: customer_id,
340                    },
341                    timestamp: Some(timestamp),
342                },
343            )
344            .await?;
345        }
346
347        if event.input_cache_creation_tokens > 0 {
348            StripeMeterEvent::create(
349                &self.client,
350                StripeCreateMeterEventParams {
351                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
352                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
353                    payload: StripeCreateMeterEventPayload {
354                        value: event.input_cache_creation_tokens as u64,
355                        stripe_customer_id: customer_id,
356                    },
357                    timestamp: Some(timestamp),
358                },
359            )
360            .await?;
361        }
362
363        if event.input_cache_read_tokens > 0 {
364            StripeMeterEvent::create(
365                &self.client,
366                StripeCreateMeterEventParams {
367                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
368                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
369                    payload: StripeCreateMeterEventPayload {
370                        value: event.input_cache_read_tokens as u64,
371                        stripe_customer_id: customer_id,
372                    },
373                    timestamp: Some(timestamp),
374                },
375            )
376            .await?;
377        }
378
379        if event.output_tokens > 0 {
380            StripeMeterEvent::create(
381                &self.client,
382                StripeCreateMeterEventParams {
383                    identifier: &format!("output_tokens/{}", event.idempotency_key),
384                    event_name: &model.output_tokens_price.meter_event_name,
385                    payload: StripeCreateMeterEventPayload {
386                        value: event.output_tokens as u64,
387                        stripe_customer_id: customer_id,
388                    },
389                    timestamp: Some(timestamp),
390                },
391            )
392            .await?;
393        }
394
395        Ok(())
396    }
397
398    pub async fn bill_model_request_usage(
399        &self,
400        customer_id: &stripe::CustomerId,
401        event_name: &str,
402        requests: i32,
403    ) -> Result<()> {
404        let timestamp = Utc::now().timestamp();
405        let idempotency_key = Uuid::new_v4();
406
407        StripeMeterEvent::create(
408            &self.client,
409            StripeCreateMeterEventParams {
410                identifier: &format!("model_requests/{}", idempotency_key),
411                event_name,
412                payload: StripeCreateMeterEventPayload {
413                    value: requests as u64,
414                    stripe_customer_id: customer_id,
415                },
416                timestamp: Some(timestamp),
417            },
418        )
419        .await?;
420
421        Ok(())
422    }
423
424    pub async fn checkout(
425        &self,
426        customer_id: stripe::CustomerId,
427        github_login: &str,
428        model: &StripeModelTokenPrices,
429        success_url: &str,
430    ) -> Result<String> {
431        let first_of_next_month = Utc::now()
432            .checked_add_months(chrono::Months::new(1))
433            .unwrap()
434            .with_day(1)
435            .unwrap();
436
437        let mut params = stripe::CreateCheckoutSession::new();
438        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
439        params.customer = Some(customer_id);
440        params.client_reference_id = Some(github_login);
441        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
442            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
443            ..Default::default()
444        });
445        params.line_items = Some(
446            [
447                &model.input_tokens_price.id,
448                &model.input_cache_creation_tokens_price.id,
449                &model.input_cache_read_tokens_price.id,
450                &model.output_tokens_price.id,
451            ]
452            .into_iter()
453            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
454                price: Some(price_id.to_string()),
455                ..Default::default()
456            })
457            .collect(),
458        );
459        params.success_url = Some(success_url);
460
461        let session = stripe::CheckoutSession::create(&self.client, params).await?;
462        Ok(session.url.context("no checkout session URL")?)
463    }
464
465    pub async fn checkout_with_price(
466        &self,
467        price_id: PriceId,
468        customer_id: stripe::CustomerId,
469        github_login: &str,
470        success_url: &str,
471    ) -> Result<String> {
472        let mut params = stripe::CreateCheckoutSession::new();
473        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
474        params.customer = Some(customer_id);
475        params.client_reference_id = Some(github_login);
476        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
477            price: Some(price_id.to_string()),
478            quantity: Some(1),
479            ..Default::default()
480        }]);
481        params.success_url = Some(success_url);
482
483        let session = stripe::CheckoutSession::create(&self.client, params).await?;
484        Ok(session.url.context("no checkout session URL")?)
485    }
486
487    pub async fn checkout_with_zed_pro_trial(
488        &self,
489        zed_pro_price_id: PriceId,
490        customer_id: stripe::CustomerId,
491        github_login: &str,
492        success_url: &str,
493    ) -> Result<String> {
494        let mut params = stripe::CreateCheckoutSession::new();
495        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
496            trial_period_days: Some(14),
497            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
498                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
499                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
500                }
501            }),
502            ..Default::default()
503        });
504        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
505        params.payment_method_collection =
506            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
507        params.customer = Some(customer_id);
508        params.client_reference_id = Some(github_login);
509        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
510            price: Some(zed_pro_price_id.to_string()),
511            quantity: Some(1),
512            ..Default::default()
513        }]);
514        params.success_url = Some(success_url);
515
516        let session = stripe::CheckoutSession::create(&self.client, params).await?;
517        Ok(session.url.context("no checkout session URL")?)
518    }
519}
520
521#[derive(Serialize)]
522struct DefaultAggregation {
523    formula: &'static str,
524}
525
526#[derive(Serialize)]
527struct StripeCreateMeterParams<'a> {
528    default_aggregation: DefaultAggregation,
529    display_name: String,
530    event_name: &'a str,
531}
532
533#[derive(Clone, Deserialize)]
534struct StripeMeter {
535    id: String,
536    event_name: String,
537}
538
539impl StripeMeter {
540    pub fn create(
541        client: &stripe::Client,
542        params: StripeCreateMeterParams,
543    ) -> stripe::Response<Self> {
544        client.post_form("/billing/meters", params)
545    }
546
547    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
548        #[derive(Serialize)]
549        struct Params {
550            #[serde(skip_serializing_if = "Option::is_none")]
551            limit: Option<u64>,
552        }
553
554        client.get_query("/billing/meters", Params { limit: Some(100) })
555    }
556}
557
558#[derive(Deserialize)]
559struct StripeMeterEvent {
560    identifier: String,
561}
562
563impl StripeMeterEvent {
564    pub async fn create(
565        client: &stripe::Client,
566        params: StripeCreateMeterEventParams<'_>,
567    ) -> Result<Self, stripe::StripeError> {
568        let identifier = params.identifier;
569        match client.post_form("/billing/meter_events", params).await {
570            Ok(event) => Ok(event),
571            Err(stripe::StripeError::Stripe(error)) => {
572                if error.http_status == 400
573                    && error
574                        .message
575                        .as_ref()
576                        .map_or(false, |message| message.contains(identifier))
577                {
578                    Ok(Self {
579                        identifier: identifier.to_string(),
580                    })
581                } else {
582                    Err(stripe::StripeError::Stripe(error))
583                }
584            }
585            Err(error) => Err(error),
586        }
587    }
588}
589
590#[derive(Serialize)]
591struct StripeCreateMeterEventParams<'a> {
592    identifier: &'a str,
593    event_name: &'a str,
594    payload: StripeCreateMeterEventPayload<'a>,
595    timestamp: Option<i64>,
596}
597
598#[derive(Serialize)]
599struct StripeCreateMeterEventPayload<'a> {
600    value: u64,
601    stripe_customer_id: &'a stripe::CustomerId,
602}
603
604fn subscription_contains_price(
605    subscription: &stripe::Subscription,
606    price_id: &stripe::PriceId,
607) -> bool {
608    subscription.items.data.iter().any(|item| {
609        item.price
610            .as_ref()
611            .map_or(false, |price| price.id == *price_id)
612    })
613}