stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
  4use crate::{Cents, Result};
  5use anyhow::{Context as _, anyhow};
  6use chrono::{Datelike, Utc};
  7use collections::HashMap;
  8use serde::{Deserialize, Serialize};
  9use stripe::PriceId;
 10use tokio::sync::RwLock;
 11use uuid::Uuid;
 12
 13pub struct StripeBilling {
 14    state: RwLock<StripeBillingState>,
 15    client: Arc<stripe::Client>,
 16}
 17
 18#[derive(Default)]
 19struct StripeBillingState {
 20    meters_by_event_name: HashMap<String, StripeMeter>,
 21    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 22    prices_by_lookup_key: HashMap<String, stripe::Price>,
 23}
 24
 25pub struct StripeModelTokenPrices {
 26    input_tokens_price: StripeBillingPrice,
 27    input_cache_creation_tokens_price: StripeBillingPrice,
 28    input_cache_read_tokens_price: StripeBillingPrice,
 29    output_tokens_price: StripeBillingPrice,
 30}
 31
 32struct StripeBillingPrice {
 33    id: stripe::PriceId,
 34    meter_event_name: String,
 35}
 36
 37impl StripeBilling {
 38    pub fn new(client: Arc<stripe::Client>) -> Self {
 39        Self {
 40            client,
 41            state: RwLock::default(),
 42        }
 43    }
 44
 45    pub async fn initialize(&self) -> Result<()> {
 46        log::info!("StripeBilling: initializing");
 47
 48        let mut state = self.state.write().await;
 49
 50        let (meters, prices) = futures::try_join!(
 51            StripeMeter::list(&self.client),
 52            stripe::Price::list(
 53                &self.client,
 54                &stripe::ListPrices {
 55                    limit: Some(100),
 56                    ..Default::default()
 57                }
 58            )
 59        )?;
 60
 61        for meter in meters.data {
 62            state
 63                .meters_by_event_name
 64                .insert(meter.event_name.clone(), meter);
 65        }
 66
 67        for price in prices.data {
 68            if let Some(lookup_key) = price.lookup_key.clone() {
 69                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 70            }
 71
 72            if let Some(recurring) = price.recurring {
 73                if let Some(meter) = recurring.meter {
 74                    state.price_ids_by_meter_id.insert(meter, price.id);
 75                }
 76            }
 77        }
 78
 79        log::info!("StripeBilling: initialized");
 80
 81        Ok(())
 82    }
 83
 84    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 85        self.find_price_id_by_lookup_key("zed-pro").await
 86    }
 87
 88    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 89        self.find_price_id_by_lookup_key("zed-free").await
 90    }
 91
 92    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 93        self.state
 94            .read()
 95            .await
 96            .prices_by_lookup_key
 97            .get(lookup_key)
 98            .map(|price| price.id.clone())
 99            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
100    }
101
102    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
103        self.state
104            .read()
105            .await
106            .prices_by_lookup_key
107            .get(lookup_key)
108            .cloned()
109            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
110    }
111
112    pub async fn register_model_for_token_based_usage(
113        &self,
114        model: &llm::db::model::Model,
115    ) -> Result<StripeModelTokenPrices> {
116        let input_tokens_price = self
117            .get_or_insert_token_price(
118                &format!("model_{}/input_tokens", model.id),
119                &format!("{} (Input Tokens)", model.name),
120                Cents::new(model.price_per_million_input_tokens as u32),
121            )
122            .await?;
123        let input_cache_creation_tokens_price = self
124            .get_or_insert_token_price(
125                &format!("model_{}/input_cache_creation_tokens", model.id),
126                &format!("{} (Input Cache Creation Tokens)", model.name),
127                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
128            )
129            .await?;
130        let input_cache_read_tokens_price = self
131            .get_or_insert_token_price(
132                &format!("model_{}/input_cache_read_tokens", model.id),
133                &format!("{} (Input Cache Read Tokens)", model.name),
134                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
135            )
136            .await?;
137        let output_tokens_price = self
138            .get_or_insert_token_price(
139                &format!("model_{}/output_tokens", model.id),
140                &format!("{} (Output Tokens)", model.name),
141                Cents::new(model.price_per_million_output_tokens as u32),
142            )
143            .await?;
144        Ok(StripeModelTokenPrices {
145            input_tokens_price,
146            input_cache_creation_tokens_price,
147            input_cache_read_tokens_price,
148            output_tokens_price,
149        })
150    }
151
152    async fn get_or_insert_token_price(
153        &self,
154        meter_event_name: &str,
155        price_description: &str,
156        price_per_million_tokens: Cents,
157    ) -> Result<StripeBillingPrice> {
158        // Fast code path when the meter and the price already exist.
159        {
160            let state = self.state.read().await;
161            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
162                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
163                    return Ok(StripeBillingPrice {
164                        id: price_id.clone(),
165                        meter_event_name: meter_event_name.to_string(),
166                    });
167                }
168            }
169        }
170
171        let mut state = self.state.write().await;
172        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
173            meter.clone()
174        } else {
175            let meter = StripeMeter::create(
176                &self.client,
177                StripeCreateMeterParams {
178                    default_aggregation: DefaultAggregation { formula: "sum" },
179                    display_name: price_description.to_string(),
180                    event_name: meter_event_name,
181                },
182            )
183            .await?;
184            state
185                .meters_by_event_name
186                .insert(meter_event_name.to_string(), meter.clone());
187            meter
188        };
189
190        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
191            price_id.clone()
192        } else {
193            let price = stripe::Price::create(
194                &self.client,
195                stripe::CreatePrice {
196                    active: Some(true),
197                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
198                    currency: stripe::Currency::USD,
199                    currency_options: None,
200                    custom_unit_amount: None,
201                    expand: &[],
202                    lookup_key: None,
203                    metadata: None,
204                    nickname: None,
205                    product: None,
206                    product_data: Some(stripe::CreatePriceProductData {
207                        id: None,
208                        active: Some(true),
209                        metadata: None,
210                        name: price_description.to_string(),
211                        statement_descriptor: None,
212                        tax_code: None,
213                        unit_label: None,
214                    }),
215                    recurring: Some(stripe::CreatePriceRecurring {
216                        aggregate_usage: None,
217                        interval: stripe::CreatePriceRecurringInterval::Month,
218                        interval_count: None,
219                        trial_period_days: None,
220                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
221                        meter: Some(meter.id.clone()),
222                    }),
223                    tax_behavior: None,
224                    tiers: None,
225                    tiers_mode: None,
226                    transfer_lookup_key: None,
227                    transform_quantity: None,
228                    unit_amount: None,
229                    unit_amount_decimal: Some(&format!(
230                        "{:.12}",
231                        price_per_million_tokens.0 as f64 / 1_000_000f64
232                    )),
233                },
234            )
235            .await?;
236            state
237                .price_ids_by_meter_id
238                .insert(meter.id, price.id.clone());
239            price.id
240        };
241
242        Ok(StripeBillingPrice {
243            id: price_id,
244            meter_event_name: meter_event_name.to_string(),
245        })
246    }
247
248    pub async fn subscribe_to_price(
249        &self,
250        subscription_id: &stripe::SubscriptionId,
251        price: &stripe::Price,
252    ) -> Result<()> {
253        let subscription =
254            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
255
256        if subscription_contains_price(&subscription, &price.id) {
257            return Ok(());
258        }
259
260        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
261
262        let price_per_unit = price.unit_amount.unwrap_or_default();
263        let units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
264
265        stripe::Subscription::update(
266            &self.client,
267            subscription_id,
268            stripe::UpdateSubscription {
269                items: Some(vec![stripe::UpdateSubscriptionItems {
270                    price: Some(price.id.to_string()),
271                    billing_thresholds: Some(stripe::SubscriptionItemBillingThresholds {
272                        usage_gte: Some(units_for_billing_threshold),
273                    }),
274                    ..Default::default()
275                }]),
276                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
277                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
278                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
279                    },
280                }),
281                ..Default::default()
282            },
283        )
284        .await?;
285
286        Ok(())
287    }
288
289    pub async fn subscribe_to_model(
290        &self,
291        subscription_id: &stripe::SubscriptionId,
292        model: &StripeModelTokenPrices,
293    ) -> Result<()> {
294        let subscription =
295            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
296
297        let mut items = Vec::new();
298
299        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
300            items.push(stripe::UpdateSubscriptionItems {
301                price: Some(model.input_tokens_price.id.to_string()),
302                ..Default::default()
303            });
304        }
305
306        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
307        {
308            items.push(stripe::UpdateSubscriptionItems {
309                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
310                ..Default::default()
311            });
312        }
313
314        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
315            items.push(stripe::UpdateSubscriptionItems {
316                price: Some(model.input_cache_read_tokens_price.id.to_string()),
317                ..Default::default()
318            });
319        }
320
321        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
322            items.push(stripe::UpdateSubscriptionItems {
323                price: Some(model.output_tokens_price.id.to_string()),
324                ..Default::default()
325            });
326        }
327
328        if !items.is_empty() {
329            items.extend(subscription.items.data.iter().map(|item| {
330                stripe::UpdateSubscriptionItems {
331                    id: Some(item.id.to_string()),
332                    ..Default::default()
333                }
334            }));
335
336            stripe::Subscription::update(
337                &self.client,
338                subscription_id,
339                stripe::UpdateSubscription {
340                    items: Some(items),
341                    ..Default::default()
342                },
343            )
344            .await?;
345        }
346
347        Ok(())
348    }
349
350    pub async fn bill_model_token_usage(
351        &self,
352        customer_id: &stripe::CustomerId,
353        model: &StripeModelTokenPrices,
354        event: &llm::db::billing_event::Model,
355    ) -> Result<()> {
356        let timestamp = Utc::now().timestamp();
357
358        if event.input_tokens > 0 {
359            StripeMeterEvent::create(
360                &self.client,
361                StripeCreateMeterEventParams {
362                    identifier: &format!("input_tokens/{}", event.idempotency_key),
363                    event_name: &model.input_tokens_price.meter_event_name,
364                    payload: StripeCreateMeterEventPayload {
365                        value: event.input_tokens as u64,
366                        stripe_customer_id: customer_id,
367                    },
368                    timestamp: Some(timestamp),
369                },
370            )
371            .await?;
372        }
373
374        if event.input_cache_creation_tokens > 0 {
375            StripeMeterEvent::create(
376                &self.client,
377                StripeCreateMeterEventParams {
378                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
379                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
380                    payload: StripeCreateMeterEventPayload {
381                        value: event.input_cache_creation_tokens as u64,
382                        stripe_customer_id: customer_id,
383                    },
384                    timestamp: Some(timestamp),
385                },
386            )
387            .await?;
388        }
389
390        if event.input_cache_read_tokens > 0 {
391            StripeMeterEvent::create(
392                &self.client,
393                StripeCreateMeterEventParams {
394                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
395                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
396                    payload: StripeCreateMeterEventPayload {
397                        value: event.input_cache_read_tokens as u64,
398                        stripe_customer_id: customer_id,
399                    },
400                    timestamp: Some(timestamp),
401                },
402            )
403            .await?;
404        }
405
406        if event.output_tokens > 0 {
407            StripeMeterEvent::create(
408                &self.client,
409                StripeCreateMeterEventParams {
410                    identifier: &format!("output_tokens/{}", event.idempotency_key),
411                    event_name: &model.output_tokens_price.meter_event_name,
412                    payload: StripeCreateMeterEventPayload {
413                        value: event.output_tokens as u64,
414                        stripe_customer_id: customer_id,
415                    },
416                    timestamp: Some(timestamp),
417                },
418            )
419            .await?;
420        }
421
422        Ok(())
423    }
424
425    pub async fn bill_model_request_usage(
426        &self,
427        customer_id: &stripe::CustomerId,
428        event_name: &str,
429        requests: i32,
430    ) -> Result<()> {
431        let timestamp = Utc::now().timestamp();
432        let idempotency_key = Uuid::new_v4();
433
434        StripeMeterEvent::create(
435            &self.client,
436            StripeCreateMeterEventParams {
437                identifier: &format!("model_requests/{}", idempotency_key),
438                event_name,
439                payload: StripeCreateMeterEventPayload {
440                    value: requests as u64,
441                    stripe_customer_id: customer_id,
442                },
443                timestamp: Some(timestamp),
444            },
445        )
446        .await?;
447
448        Ok(())
449    }
450
451    pub async fn checkout(
452        &self,
453        customer_id: stripe::CustomerId,
454        github_login: &str,
455        model: &StripeModelTokenPrices,
456        success_url: &str,
457    ) -> Result<String> {
458        let first_of_next_month = Utc::now()
459            .checked_add_months(chrono::Months::new(1))
460            .unwrap()
461            .with_day(1)
462            .unwrap();
463
464        let mut params = stripe::CreateCheckoutSession::new();
465        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
466        params.customer = Some(customer_id);
467        params.client_reference_id = Some(github_login);
468        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
469            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
470            ..Default::default()
471        });
472        params.line_items = Some(
473            [
474                &model.input_tokens_price.id,
475                &model.input_cache_creation_tokens_price.id,
476                &model.input_cache_read_tokens_price.id,
477                &model.output_tokens_price.id,
478            ]
479            .into_iter()
480            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
481                price: Some(price_id.to_string()),
482                ..Default::default()
483            })
484            .collect(),
485        );
486        params.success_url = Some(success_url);
487
488        let session = stripe::CheckoutSession::create(&self.client, params).await?;
489        Ok(session.url.context("no checkout session URL")?)
490    }
491
492    pub async fn checkout_with_zed_pro(
493        &self,
494        customer_id: stripe::CustomerId,
495        github_login: &str,
496        success_url: &str,
497    ) -> Result<String> {
498        let zed_pro_price_id = self.zed_pro_price_id().await?;
499
500        let mut params = stripe::CreateCheckoutSession::new();
501        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
502        params.customer = Some(customer_id);
503        params.client_reference_id = Some(github_login);
504        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
505            price: Some(zed_pro_price_id.to_string()),
506            quantity: Some(1),
507            ..Default::default()
508        }]);
509        params.success_url = Some(success_url);
510
511        let session = stripe::CheckoutSession::create(&self.client, params).await?;
512        Ok(session.url.context("no checkout session URL")?)
513    }
514
515    pub async fn checkout_with_zed_pro_trial(
516        &self,
517        customer_id: stripe::CustomerId,
518        github_login: &str,
519        feature_flags: Vec<String>,
520        success_url: &str,
521    ) -> Result<String> {
522        let zed_pro_price_id = self.zed_pro_price_id().await?;
523
524        let eligible_for_extended_trial = feature_flags
525            .iter()
526            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
527
528        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
529
530        let mut subscription_metadata = std::collections::HashMap::new();
531        if eligible_for_extended_trial {
532            subscription_metadata.insert(
533                "promo_feature_flag".to_string(),
534                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
535            );
536        }
537
538        let mut params = stripe::CreateCheckoutSession::new();
539        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
540            trial_period_days: Some(trial_period_days),
541            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
542                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
543                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
544                }
545            }),
546            metadata: if !subscription_metadata.is_empty() {
547                Some(subscription_metadata)
548            } else {
549                None
550            },
551            ..Default::default()
552        });
553        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
554        params.payment_method_collection =
555            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
556        params.customer = Some(customer_id);
557        params.client_reference_id = Some(github_login);
558        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
559            price: Some(zed_pro_price_id.to_string()),
560            quantity: Some(1),
561            ..Default::default()
562        }]);
563        params.success_url = Some(success_url);
564
565        let session = stripe::CheckoutSession::create(&self.client, params).await?;
566        Ok(session.url.context("no checkout session URL")?)
567    }
568
569    pub async fn checkout_with_zed_free(
570        &self,
571        customer_id: stripe::CustomerId,
572        github_login: &str,
573        success_url: &str,
574    ) -> Result<String> {
575        let zed_free_price_id = self.zed_free_price_id().await?;
576
577        let mut params = stripe::CreateCheckoutSession::new();
578        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
579        params.customer = Some(customer_id);
580        params.client_reference_id = Some(github_login);
581        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
582            price: Some(zed_free_price_id.to_string()),
583            quantity: Some(1),
584            ..Default::default()
585        }]);
586        params.success_url = Some(success_url);
587
588        let session = stripe::CheckoutSession::create(&self.client, params).await?;
589        Ok(session.url.context("no checkout session URL")?)
590    }
591}
592
593#[derive(Serialize)]
594struct DefaultAggregation {
595    formula: &'static str,
596}
597
598#[derive(Serialize)]
599struct StripeCreateMeterParams<'a> {
600    default_aggregation: DefaultAggregation,
601    display_name: String,
602    event_name: &'a str,
603}
604
605#[derive(Clone, Deserialize)]
606struct StripeMeter {
607    id: String,
608    event_name: String,
609}
610
611impl StripeMeter {
612    pub fn create(
613        client: &stripe::Client,
614        params: StripeCreateMeterParams,
615    ) -> stripe::Response<Self> {
616        client.post_form("/billing/meters", params)
617    }
618
619    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
620        #[derive(Serialize)]
621        struct Params {
622            #[serde(skip_serializing_if = "Option::is_none")]
623            limit: Option<u64>,
624        }
625
626        client.get_query("/billing/meters", Params { limit: Some(100) })
627    }
628}
629
630#[derive(Deserialize)]
631struct StripeMeterEvent {
632    identifier: String,
633}
634
635impl StripeMeterEvent {
636    pub async fn create(
637        client: &stripe::Client,
638        params: StripeCreateMeterEventParams<'_>,
639    ) -> Result<Self, stripe::StripeError> {
640        let identifier = params.identifier;
641        match client.post_form("/billing/meter_events", params).await {
642            Ok(event) => Ok(event),
643            Err(stripe::StripeError::Stripe(error)) => {
644                if error.http_status == 400
645                    && error
646                        .message
647                        .as_ref()
648                        .map_or(false, |message| message.contains(identifier))
649                {
650                    Ok(Self {
651                        identifier: identifier.to_string(),
652                    })
653                } else {
654                    Err(stripe::StripeError::Stripe(error))
655                }
656            }
657            Err(error) => Err(error),
658        }
659    }
660}
661
662#[derive(Serialize)]
663struct StripeCreateMeterEventParams<'a> {
664    identifier: &'a str,
665    event_name: &'a str,
666    payload: StripeCreateMeterEventPayload<'a>,
667    timestamp: Option<i64>,
668}
669
670#[derive(Serialize)]
671struct StripeCreateMeterEventPayload<'a> {
672    value: u64,
673    stripe_customer_id: &'a stripe::CustomerId,
674}
675
676fn subscription_contains_price(
677    subscription: &stripe::Subscription,
678    price_id: &stripe::PriceId,
679) -> bool {
680    subscription.items.data.iter().any(|item| {
681        item.price
682            .as_ref()
683            .map_or(false, |price| price.id == *price_id)
684    })
685}