stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::{Cents, Result, llm};
  4use anyhow::{Context as _, anyhow};
  5use chrono::{Datelike, Utc};
  6use collections::HashMap;
  7use serde::{Deserialize, Serialize};
  8use tokio::sync::RwLock;
  9
 10pub struct StripeBilling {
 11    state: RwLock<StripeBillingState>,
 12    client: Arc<stripe::Client>,
 13    zed_pro_price_id: Option<String>,
 14}
 15
 16#[derive(Default)]
 17struct StripeBillingState {
 18    meters_by_event_name: HashMap<String, StripeMeter>,
 19    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 20}
 21
 22pub struct StripeModel {
 23    input_tokens_price: StripeBillingPrice,
 24    input_cache_creation_tokens_price: StripeBillingPrice,
 25    input_cache_read_tokens_price: StripeBillingPrice,
 26    output_tokens_price: StripeBillingPrice,
 27}
 28
 29struct StripeBillingPrice {
 30    id: stripe::PriceId,
 31    meter_event_name: String,
 32}
 33
 34impl StripeBilling {
 35    pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
 36        Self {
 37            client,
 38            state: RwLock::default(),
 39            zed_pro_price_id,
 40        }
 41    }
 42
 43    pub async fn initialize(&self) -> Result<()> {
 44        log::info!("StripeBilling: initializing");
 45
 46        let mut state = self.state.write().await;
 47
 48        let (meters, prices) = futures::try_join!(
 49            StripeMeter::list(&self.client),
 50            stripe::Price::list(
 51                &self.client,
 52                &stripe::ListPrices {
 53                    limit: Some(100),
 54                    ..Default::default()
 55                }
 56            )
 57        )?;
 58
 59        for meter in meters.data {
 60            state
 61                .meters_by_event_name
 62                .insert(meter.event_name.clone(), meter);
 63        }
 64
 65        for price in prices.data {
 66            if let Some(recurring) = price.recurring {
 67                if let Some(meter) = recurring.meter {
 68                    state.price_ids_by_meter_id.insert(meter, price.id);
 69                }
 70            }
 71        }
 72
 73        log::info!("StripeBilling: initialized");
 74
 75        Ok(())
 76    }
 77
 78    pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
 79        let input_tokens_price = self
 80            .get_or_insert_price(
 81                &format!("model_{}/input_tokens", model.id),
 82                &format!("{} (Input Tokens)", model.name),
 83                Cents::new(model.price_per_million_input_tokens as u32),
 84            )
 85            .await?;
 86        let input_cache_creation_tokens_price = self
 87            .get_or_insert_price(
 88                &format!("model_{}/input_cache_creation_tokens", model.id),
 89                &format!("{} (Input Cache Creation Tokens)", model.name),
 90                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
 91            )
 92            .await?;
 93        let input_cache_read_tokens_price = self
 94            .get_or_insert_price(
 95                &format!("model_{}/input_cache_read_tokens", model.id),
 96                &format!("{} (Input Cache Read Tokens)", model.name),
 97                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
 98            )
 99            .await?;
100        let output_tokens_price = self
101            .get_or_insert_price(
102                &format!("model_{}/output_tokens", model.id),
103                &format!("{} (Output Tokens)", model.name),
104                Cents::new(model.price_per_million_output_tokens as u32),
105            )
106            .await?;
107        Ok(StripeModel {
108            input_tokens_price,
109            input_cache_creation_tokens_price,
110            input_cache_read_tokens_price,
111            output_tokens_price,
112        })
113    }
114
115    async fn get_or_insert_price(
116        &self,
117        meter_event_name: &str,
118        price_description: &str,
119        price_per_million_tokens: Cents,
120    ) -> Result<StripeBillingPrice> {
121        // Fast code path when the meter and the price already exist.
122        {
123            let state = self.state.read().await;
124            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
125                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
126                    return Ok(StripeBillingPrice {
127                        id: price_id.clone(),
128                        meter_event_name: meter_event_name.to_string(),
129                    });
130                }
131            }
132        }
133
134        let mut state = self.state.write().await;
135        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
136            meter.clone()
137        } else {
138            let meter = StripeMeter::create(
139                &self.client,
140                StripeCreateMeterParams {
141                    default_aggregation: DefaultAggregation { formula: "sum" },
142                    display_name: price_description.to_string(),
143                    event_name: meter_event_name,
144                },
145            )
146            .await?;
147            state
148                .meters_by_event_name
149                .insert(meter_event_name.to_string(), meter.clone());
150            meter
151        };
152
153        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
154            price_id.clone()
155        } else {
156            let price = stripe::Price::create(
157                &self.client,
158                stripe::CreatePrice {
159                    active: Some(true),
160                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
161                    currency: stripe::Currency::USD,
162                    currency_options: None,
163                    custom_unit_amount: None,
164                    expand: &[],
165                    lookup_key: None,
166                    metadata: None,
167                    nickname: None,
168                    product: None,
169                    product_data: Some(stripe::CreatePriceProductData {
170                        id: None,
171                        active: Some(true),
172                        metadata: None,
173                        name: price_description.to_string(),
174                        statement_descriptor: None,
175                        tax_code: None,
176                        unit_label: None,
177                    }),
178                    recurring: Some(stripe::CreatePriceRecurring {
179                        aggregate_usage: None,
180                        interval: stripe::CreatePriceRecurringInterval::Month,
181                        interval_count: None,
182                        trial_period_days: None,
183                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
184                        meter: Some(meter.id.clone()),
185                    }),
186                    tax_behavior: None,
187                    tiers: None,
188                    tiers_mode: None,
189                    transfer_lookup_key: None,
190                    transform_quantity: None,
191                    unit_amount: None,
192                    unit_amount_decimal: Some(&format!(
193                        "{:.12}",
194                        price_per_million_tokens.0 as f64 / 1_000_000f64
195                    )),
196                },
197            )
198            .await?;
199            state
200                .price_ids_by_meter_id
201                .insert(meter.id, price.id.clone());
202            price.id
203        };
204
205        Ok(StripeBillingPrice {
206            id: price_id,
207            meter_event_name: meter_event_name.to_string(),
208        })
209    }
210
211    pub async fn subscribe_to_model(
212        &self,
213        subscription_id: &stripe::SubscriptionId,
214        model: &StripeModel,
215    ) -> Result<()> {
216        let subscription =
217            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
218
219        let mut items = Vec::new();
220
221        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
222            items.push(stripe::UpdateSubscriptionItems {
223                price: Some(model.input_tokens_price.id.to_string()),
224                ..Default::default()
225            });
226        }
227
228        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
229        {
230            items.push(stripe::UpdateSubscriptionItems {
231                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
232                ..Default::default()
233            });
234        }
235
236        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
237            items.push(stripe::UpdateSubscriptionItems {
238                price: Some(model.input_cache_read_tokens_price.id.to_string()),
239                ..Default::default()
240            });
241        }
242
243        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
244            items.push(stripe::UpdateSubscriptionItems {
245                price: Some(model.output_tokens_price.id.to_string()),
246                ..Default::default()
247            });
248        }
249
250        if !items.is_empty() {
251            items.extend(subscription.items.data.iter().map(|item| {
252                stripe::UpdateSubscriptionItems {
253                    id: Some(item.id.to_string()),
254                    ..Default::default()
255                }
256            }));
257
258            stripe::Subscription::update(
259                &self.client,
260                subscription_id,
261                stripe::UpdateSubscription {
262                    items: Some(items),
263                    ..Default::default()
264                },
265            )
266            .await?;
267        }
268
269        Ok(())
270    }
271
272    pub async fn bill_model_usage(
273        &self,
274        customer_id: &stripe::CustomerId,
275        model: &StripeModel,
276        event: &llm::db::billing_event::Model,
277    ) -> Result<()> {
278        let timestamp = Utc::now().timestamp();
279
280        if event.input_tokens > 0 {
281            StripeMeterEvent::create(
282                &self.client,
283                StripeCreateMeterEventParams {
284                    identifier: &format!("input_tokens/{}", event.idempotency_key),
285                    event_name: &model.input_tokens_price.meter_event_name,
286                    payload: StripeCreateMeterEventPayload {
287                        value: event.input_tokens as u64,
288                        stripe_customer_id: customer_id,
289                    },
290                    timestamp: Some(timestamp),
291                },
292            )
293            .await?;
294        }
295
296        if event.input_cache_creation_tokens > 0 {
297            StripeMeterEvent::create(
298                &self.client,
299                StripeCreateMeterEventParams {
300                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
301                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
302                    payload: StripeCreateMeterEventPayload {
303                        value: event.input_cache_creation_tokens as u64,
304                        stripe_customer_id: customer_id,
305                    },
306                    timestamp: Some(timestamp),
307                },
308            )
309            .await?;
310        }
311
312        if event.input_cache_read_tokens > 0 {
313            StripeMeterEvent::create(
314                &self.client,
315                StripeCreateMeterEventParams {
316                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
317                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
318                    payload: StripeCreateMeterEventPayload {
319                        value: event.input_cache_read_tokens as u64,
320                        stripe_customer_id: customer_id,
321                    },
322                    timestamp: Some(timestamp),
323                },
324            )
325            .await?;
326        }
327
328        if event.output_tokens > 0 {
329            StripeMeterEvent::create(
330                &self.client,
331                StripeCreateMeterEventParams {
332                    identifier: &format!("output_tokens/{}", event.idempotency_key),
333                    event_name: &model.output_tokens_price.meter_event_name,
334                    payload: StripeCreateMeterEventPayload {
335                        value: event.output_tokens as u64,
336                        stripe_customer_id: customer_id,
337                    },
338                    timestamp: Some(timestamp),
339                },
340            )
341            .await?;
342        }
343
344        Ok(())
345    }
346
347    pub async fn checkout(
348        &self,
349        customer_id: stripe::CustomerId,
350        github_login: &str,
351        model: &StripeModel,
352        success_url: &str,
353    ) -> Result<String> {
354        let first_of_next_month = Utc::now()
355            .checked_add_months(chrono::Months::new(1))
356            .unwrap()
357            .with_day(1)
358            .unwrap();
359
360        let mut params = stripe::CreateCheckoutSession::new();
361        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
362        params.customer = Some(customer_id);
363        params.client_reference_id = Some(github_login);
364        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
365            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
366            ..Default::default()
367        });
368        params.line_items = Some(
369            [
370                &model.input_tokens_price.id,
371                &model.input_cache_creation_tokens_price.id,
372                &model.input_cache_read_tokens_price.id,
373                &model.output_tokens_price.id,
374            ]
375            .into_iter()
376            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
377                price: Some(price_id.to_string()),
378                ..Default::default()
379            })
380            .collect(),
381        );
382        params.success_url = Some(success_url);
383
384        let session = stripe::CheckoutSession::create(&self.client, params).await?;
385        Ok(session.url.context("no checkout session URL")?)
386    }
387
388    pub async fn checkout_with_zed_pro(
389        &self,
390        customer_id: stripe::CustomerId,
391        github_login: &str,
392        success_url: &str,
393    ) -> Result<String> {
394        let zed_pro_price_id = self
395            .zed_pro_price_id
396            .as_ref()
397            .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
398
399        let mut params = stripe::CreateCheckoutSession::new();
400        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
401        params.customer = Some(customer_id);
402        params.client_reference_id = Some(github_login);
403        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
404            price: Some(zed_pro_price_id.clone()),
405            quantity: Some(1),
406            ..Default::default()
407        }]);
408        params.success_url = Some(success_url);
409
410        let session = stripe::CheckoutSession::create(&self.client, params).await?;
411        Ok(session.url.context("no checkout session URL")?)
412    }
413}
414
415#[derive(Serialize)]
416struct DefaultAggregation {
417    formula: &'static str,
418}
419
420#[derive(Serialize)]
421struct StripeCreateMeterParams<'a> {
422    default_aggregation: DefaultAggregation,
423    display_name: String,
424    event_name: &'a str,
425}
426
427#[derive(Clone, Deserialize)]
428struct StripeMeter {
429    id: String,
430    event_name: String,
431}
432
433impl StripeMeter {
434    pub fn create(
435        client: &stripe::Client,
436        params: StripeCreateMeterParams,
437    ) -> stripe::Response<Self> {
438        client.post_form("/billing/meters", params)
439    }
440
441    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
442        #[derive(Serialize)]
443        struct Params {
444            #[serde(skip_serializing_if = "Option::is_none")]
445            limit: Option<u64>,
446        }
447
448        client.get_query("/billing/meters", Params { limit: Some(100) })
449    }
450}
451
452#[derive(Deserialize)]
453struct StripeMeterEvent {
454    identifier: String,
455}
456
457impl StripeMeterEvent {
458    pub async fn create(
459        client: &stripe::Client,
460        params: StripeCreateMeterEventParams<'_>,
461    ) -> Result<Self, stripe::StripeError> {
462        let identifier = params.identifier;
463        match client.post_form("/billing/meter_events", params).await {
464            Ok(event) => Ok(event),
465            Err(stripe::StripeError::Stripe(error)) => {
466                if error.http_status == 400
467                    && error
468                        .message
469                        .as_ref()
470                        .map_or(false, |message| message.contains(identifier))
471                {
472                    Ok(Self {
473                        identifier: identifier.to_string(),
474                    })
475                } else {
476                    Err(stripe::StripeError::Stripe(error))
477                }
478            }
479            Err(error) => Err(error),
480        }
481    }
482}
483
484#[derive(Serialize)]
485struct StripeCreateMeterEventParams<'a> {
486    identifier: &'a str,
487    event_name: &'a str,
488    payload: StripeCreateMeterEventPayload<'a>,
489    timestamp: Option<i64>,
490}
491
492#[derive(Serialize)]
493struct StripeCreateMeterEventPayload<'a> {
494    value: u64,
495    stripe_customer_id: &'a stripe::CustomerId,
496}
497
498fn subscription_contains_price(
499    subscription: &stripe::Subscription,
500    price_id: &stripe::PriceId,
501) -> bool {
502    subscription.items.data.iter().any(|item| {
503        item.price
504            .as_ref()
505            .map_or(false, |price| price.id == *price_id)
506    })
507}