stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
  4use crate::{Cents, Result};
  5use anyhow::{Context as _, anyhow};
  6use chrono::{Datelike, Utc};
  7use collections::HashMap;
  8use serde::{Deserialize, Serialize};
  9use stripe::PriceId;
 10use tokio::sync::RwLock;
 11use uuid::Uuid;
 12
 13pub struct StripeBilling {
 14    state: RwLock<StripeBillingState>,
 15    client: Arc<stripe::Client>,
 16}
 17
 18#[derive(Default)]
 19struct StripeBillingState {
 20    meters_by_event_name: HashMap<String, StripeMeter>,
 21    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 22    prices_by_lookup_key: HashMap<String, stripe::Price>,
 23}
 24
 25pub struct StripeModelTokenPrices {
 26    input_tokens_price: StripeBillingPrice,
 27    input_cache_creation_tokens_price: StripeBillingPrice,
 28    input_cache_read_tokens_price: StripeBillingPrice,
 29    output_tokens_price: StripeBillingPrice,
 30}
 31
 32struct StripeBillingPrice {
 33    id: stripe::PriceId,
 34    meter_event_name: String,
 35}
 36
 37impl StripeBilling {
 38    pub fn new(client: Arc<stripe::Client>) -> Self {
 39        Self {
 40            client,
 41            state: RwLock::default(),
 42        }
 43    }
 44
 45    pub async fn initialize(&self) -> Result<()> {
 46        log::info!("StripeBilling: initializing");
 47
 48        let mut state = self.state.write().await;
 49
 50        let (meters, prices) = futures::try_join!(
 51            StripeMeter::list(&self.client),
 52            stripe::Price::list(
 53                &self.client,
 54                &stripe::ListPrices {
 55                    limit: Some(100),
 56                    ..Default::default()
 57                }
 58            )
 59        )?;
 60
 61        for meter in meters.data {
 62            state
 63                .meters_by_event_name
 64                .insert(meter.event_name.clone(), meter);
 65        }
 66
 67        for price in prices.data {
 68            if let Some(lookup_key) = price.lookup_key.clone() {
 69                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 70            }
 71
 72            if let Some(recurring) = price.recurring {
 73                if let Some(meter) = recurring.meter {
 74                    state.price_ids_by_meter_id.insert(meter, price.id);
 75                }
 76            }
 77        }
 78
 79        log::info!("StripeBilling: initialized");
 80
 81        Ok(())
 82    }
 83
 84    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 85        self.find_price_id_by_lookup_key("zed-pro").await
 86    }
 87
 88    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 89        self.find_price_id_by_lookup_key("zed-free").await
 90    }
 91
 92    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 93        self.state
 94            .read()
 95            .await
 96            .prices_by_lookup_key
 97            .get(lookup_key)
 98            .map(|price| price.id.clone())
 99            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
100    }
101
102    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
103        self.state
104            .read()
105            .await
106            .prices_by_lookup_key
107            .get(lookup_key)
108            .cloned()
109            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
110    }
111
112    pub async fn register_model_for_token_based_usage(
113        &self,
114        model: &llm::db::model::Model,
115    ) -> Result<StripeModelTokenPrices> {
116        let input_tokens_price = self
117            .get_or_insert_token_price(
118                &format!("model_{}/input_tokens", model.id),
119                &format!("{} (Input Tokens)", model.name),
120                Cents::new(model.price_per_million_input_tokens as u32),
121            )
122            .await?;
123        let input_cache_creation_tokens_price = self
124            .get_or_insert_token_price(
125                &format!("model_{}/input_cache_creation_tokens", model.id),
126                &format!("{} (Input Cache Creation Tokens)", model.name),
127                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
128            )
129            .await?;
130        let input_cache_read_tokens_price = self
131            .get_or_insert_token_price(
132                &format!("model_{}/input_cache_read_tokens", model.id),
133                &format!("{} (Input Cache Read Tokens)", model.name),
134                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
135            )
136            .await?;
137        let output_tokens_price = self
138            .get_or_insert_token_price(
139                &format!("model_{}/output_tokens", model.id),
140                &format!("{} (Output Tokens)", model.name),
141                Cents::new(model.price_per_million_output_tokens as u32),
142            )
143            .await?;
144        Ok(StripeModelTokenPrices {
145            input_tokens_price,
146            input_cache_creation_tokens_price,
147            input_cache_read_tokens_price,
148            output_tokens_price,
149        })
150    }
151
152    async fn get_or_insert_token_price(
153        &self,
154        meter_event_name: &str,
155        price_description: &str,
156        price_per_million_tokens: Cents,
157    ) -> Result<StripeBillingPrice> {
158        // Fast code path when the meter and the price already exist.
159        {
160            let state = self.state.read().await;
161            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
162                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
163                    return Ok(StripeBillingPrice {
164                        id: price_id.clone(),
165                        meter_event_name: meter_event_name.to_string(),
166                    });
167                }
168            }
169        }
170
171        let mut state = self.state.write().await;
172        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
173            meter.clone()
174        } else {
175            let meter = StripeMeter::create(
176                &self.client,
177                StripeCreateMeterParams {
178                    default_aggregation: DefaultAggregation { formula: "sum" },
179                    display_name: price_description.to_string(),
180                    event_name: meter_event_name,
181                },
182            )
183            .await?;
184            state
185                .meters_by_event_name
186                .insert(meter_event_name.to_string(), meter.clone());
187            meter
188        };
189
190        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
191            price_id.clone()
192        } else {
193            let price = stripe::Price::create(
194                &self.client,
195                stripe::CreatePrice {
196                    active: Some(true),
197                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
198                    currency: stripe::Currency::USD,
199                    currency_options: None,
200                    custom_unit_amount: None,
201                    expand: &[],
202                    lookup_key: None,
203                    metadata: None,
204                    nickname: None,
205                    product: None,
206                    product_data: Some(stripe::CreatePriceProductData {
207                        id: None,
208                        active: Some(true),
209                        metadata: None,
210                        name: price_description.to_string(),
211                        statement_descriptor: None,
212                        tax_code: None,
213                        unit_label: None,
214                    }),
215                    recurring: Some(stripe::CreatePriceRecurring {
216                        aggregate_usage: None,
217                        interval: stripe::CreatePriceRecurringInterval::Month,
218                        interval_count: None,
219                        trial_period_days: None,
220                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
221                        meter: Some(meter.id.clone()),
222                    }),
223                    tax_behavior: None,
224                    tiers: None,
225                    tiers_mode: None,
226                    transfer_lookup_key: None,
227                    transform_quantity: None,
228                    unit_amount: None,
229                    unit_amount_decimal: Some(&format!(
230                        "{:.12}",
231                        price_per_million_tokens.0 as f64 / 1_000_000f64
232                    )),
233                },
234            )
235            .await?;
236            state
237                .price_ids_by_meter_id
238                .insert(meter.id, price.id.clone());
239            price.id
240        };
241
242        Ok(StripeBillingPrice {
243            id: price_id,
244            meter_event_name: meter_event_name.to_string(),
245        })
246    }
247
248    pub async fn subscribe_to_price(
249        &self,
250        subscription_id: &stripe::SubscriptionId,
251        price: &stripe::Price,
252    ) -> Result<()> {
253        let subscription =
254            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
255
256        if subscription_contains_price(&subscription, &price.id) {
257            return Ok(());
258        }
259
260        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
261
262        let price_per_unit = price.unit_amount.unwrap_or_default();
263        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
264
265        stripe::Subscription::update(
266            &self.client,
267            subscription_id,
268            stripe::UpdateSubscription {
269                items: Some(vec![stripe::UpdateSubscriptionItems {
270                    price: Some(price.id.to_string()),
271                    ..Default::default()
272                }]),
273                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
274                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
275                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
276                    },
277                }),
278                ..Default::default()
279            },
280        )
281        .await?;
282
283        Ok(())
284    }
285
286    pub async fn subscribe_to_model(
287        &self,
288        subscription_id: &stripe::SubscriptionId,
289        model: &StripeModelTokenPrices,
290    ) -> Result<()> {
291        let subscription =
292            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
293
294        let mut items = Vec::new();
295
296        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
297            items.push(stripe::UpdateSubscriptionItems {
298                price: Some(model.input_tokens_price.id.to_string()),
299                ..Default::default()
300            });
301        }
302
303        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
304        {
305            items.push(stripe::UpdateSubscriptionItems {
306                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
307                ..Default::default()
308            });
309        }
310
311        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
312            items.push(stripe::UpdateSubscriptionItems {
313                price: Some(model.input_cache_read_tokens_price.id.to_string()),
314                ..Default::default()
315            });
316        }
317
318        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
319            items.push(stripe::UpdateSubscriptionItems {
320                price: Some(model.output_tokens_price.id.to_string()),
321                ..Default::default()
322            });
323        }
324
325        if !items.is_empty() {
326            items.extend(subscription.items.data.iter().map(|item| {
327                stripe::UpdateSubscriptionItems {
328                    id: Some(item.id.to_string()),
329                    ..Default::default()
330                }
331            }));
332
333            stripe::Subscription::update(
334                &self.client,
335                subscription_id,
336                stripe::UpdateSubscription {
337                    items: Some(items),
338                    ..Default::default()
339                },
340            )
341            .await?;
342        }
343
344        Ok(())
345    }
346
347    pub async fn bill_model_token_usage(
348        &self,
349        customer_id: &stripe::CustomerId,
350        model: &StripeModelTokenPrices,
351        event: &llm::db::billing_event::Model,
352    ) -> Result<()> {
353        let timestamp = Utc::now().timestamp();
354
355        if event.input_tokens > 0 {
356            StripeMeterEvent::create(
357                &self.client,
358                StripeCreateMeterEventParams {
359                    identifier: &format!("input_tokens/{}", event.idempotency_key),
360                    event_name: &model.input_tokens_price.meter_event_name,
361                    payload: StripeCreateMeterEventPayload {
362                        value: event.input_tokens as u64,
363                        stripe_customer_id: customer_id,
364                    },
365                    timestamp: Some(timestamp),
366                },
367            )
368            .await?;
369        }
370
371        if event.input_cache_creation_tokens > 0 {
372            StripeMeterEvent::create(
373                &self.client,
374                StripeCreateMeterEventParams {
375                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
376                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
377                    payload: StripeCreateMeterEventPayload {
378                        value: event.input_cache_creation_tokens as u64,
379                        stripe_customer_id: customer_id,
380                    },
381                    timestamp: Some(timestamp),
382                },
383            )
384            .await?;
385        }
386
387        if event.input_cache_read_tokens > 0 {
388            StripeMeterEvent::create(
389                &self.client,
390                StripeCreateMeterEventParams {
391                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
392                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
393                    payload: StripeCreateMeterEventPayload {
394                        value: event.input_cache_read_tokens as u64,
395                        stripe_customer_id: customer_id,
396                    },
397                    timestamp: Some(timestamp),
398                },
399            )
400            .await?;
401        }
402
403        if event.output_tokens > 0 {
404            StripeMeterEvent::create(
405                &self.client,
406                StripeCreateMeterEventParams {
407                    identifier: &format!("output_tokens/{}", event.idempotency_key),
408                    event_name: &model.output_tokens_price.meter_event_name,
409                    payload: StripeCreateMeterEventPayload {
410                        value: event.output_tokens as u64,
411                        stripe_customer_id: customer_id,
412                    },
413                    timestamp: Some(timestamp),
414                },
415            )
416            .await?;
417        }
418
419        Ok(())
420    }
421
422    pub async fn bill_model_request_usage(
423        &self,
424        customer_id: &stripe::CustomerId,
425        event_name: &str,
426        requests: i32,
427    ) -> Result<()> {
428        let timestamp = Utc::now().timestamp();
429        let idempotency_key = Uuid::new_v4();
430
431        StripeMeterEvent::create(
432            &self.client,
433            StripeCreateMeterEventParams {
434                identifier: &format!("model_requests/{}", idempotency_key),
435                event_name,
436                payload: StripeCreateMeterEventPayload {
437                    value: requests as u64,
438                    stripe_customer_id: customer_id,
439                },
440                timestamp: Some(timestamp),
441            },
442        )
443        .await?;
444
445        Ok(())
446    }
447
448    pub async fn checkout(
449        &self,
450        customer_id: stripe::CustomerId,
451        github_login: &str,
452        model: &StripeModelTokenPrices,
453        success_url: &str,
454    ) -> Result<String> {
455        let first_of_next_month = Utc::now()
456            .checked_add_months(chrono::Months::new(1))
457            .unwrap()
458            .with_day(1)
459            .unwrap();
460
461        let mut params = stripe::CreateCheckoutSession::new();
462        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
463        params.customer = Some(customer_id);
464        params.client_reference_id = Some(github_login);
465        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
466            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
467            ..Default::default()
468        });
469        params.line_items = Some(
470            [
471                &model.input_tokens_price.id,
472                &model.input_cache_creation_tokens_price.id,
473                &model.input_cache_read_tokens_price.id,
474                &model.output_tokens_price.id,
475            ]
476            .into_iter()
477            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
478                price: Some(price_id.to_string()),
479                ..Default::default()
480            })
481            .collect(),
482        );
483        params.success_url = Some(success_url);
484
485        let session = stripe::CheckoutSession::create(&self.client, params).await?;
486        Ok(session.url.context("no checkout session URL")?)
487    }
488
489    pub async fn checkout_with_zed_pro(
490        &self,
491        customer_id: stripe::CustomerId,
492        github_login: &str,
493        success_url: &str,
494    ) -> Result<String> {
495        let zed_pro_price_id = self.zed_pro_price_id().await?;
496
497        let mut params = stripe::CreateCheckoutSession::new();
498        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
499        params.customer = Some(customer_id);
500        params.client_reference_id = Some(github_login);
501        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
502            price: Some(zed_pro_price_id.to_string()),
503            quantity: Some(1),
504            ..Default::default()
505        }]);
506        params.success_url = Some(success_url);
507
508        let session = stripe::CheckoutSession::create(&self.client, params).await?;
509        Ok(session.url.context("no checkout session URL")?)
510    }
511
512    pub async fn checkout_with_zed_pro_trial(
513        &self,
514        customer_id: stripe::CustomerId,
515        github_login: &str,
516        feature_flags: Vec<String>,
517        success_url: &str,
518    ) -> Result<String> {
519        let zed_pro_price_id = self.zed_pro_price_id().await?;
520
521        let eligible_for_extended_trial = feature_flags
522            .iter()
523            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
524
525        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
526
527        let mut subscription_metadata = std::collections::HashMap::new();
528        if eligible_for_extended_trial {
529            subscription_metadata.insert(
530                "promo_feature_flag".to_string(),
531                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
532            );
533        }
534
535        let mut params = stripe::CreateCheckoutSession::new();
536        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
537            trial_period_days: Some(trial_period_days),
538            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
539                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
540                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
541                }
542            }),
543            metadata: if !subscription_metadata.is_empty() {
544                Some(subscription_metadata)
545            } else {
546                None
547            },
548            ..Default::default()
549        });
550        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
551        params.payment_method_collection =
552            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
553        params.customer = Some(customer_id);
554        params.client_reference_id = Some(github_login);
555        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
556            price: Some(zed_pro_price_id.to_string()),
557            quantity: Some(1),
558            ..Default::default()
559        }]);
560        params.success_url = Some(success_url);
561
562        let session = stripe::CheckoutSession::create(&self.client, params).await?;
563        Ok(session.url.context("no checkout session URL")?)
564    }
565
566    pub async fn checkout_with_zed_free(
567        &self,
568        customer_id: stripe::CustomerId,
569        github_login: &str,
570        success_url: &str,
571    ) -> Result<String> {
572        let zed_free_price_id = self.zed_free_price_id().await?;
573
574        let mut params = stripe::CreateCheckoutSession::new();
575        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
576        params.customer = Some(customer_id);
577        params.client_reference_id = Some(github_login);
578        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
579            price: Some(zed_free_price_id.to_string()),
580            quantity: Some(1),
581            ..Default::default()
582        }]);
583        params.success_url = Some(success_url);
584
585        let session = stripe::CheckoutSession::create(&self.client, params).await?;
586        Ok(session.url.context("no checkout session URL")?)
587    }
588}
589
590#[derive(Serialize)]
591struct DefaultAggregation {
592    formula: &'static str,
593}
594
595#[derive(Serialize)]
596struct StripeCreateMeterParams<'a> {
597    default_aggregation: DefaultAggregation,
598    display_name: String,
599    event_name: &'a str,
600}
601
602#[derive(Clone, Deserialize)]
603struct StripeMeter {
604    id: String,
605    event_name: String,
606}
607
608impl StripeMeter {
609    pub fn create(
610        client: &stripe::Client,
611        params: StripeCreateMeterParams,
612    ) -> stripe::Response<Self> {
613        client.post_form("/billing/meters", params)
614    }
615
616    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
617        #[derive(Serialize)]
618        struct Params {
619            #[serde(skip_serializing_if = "Option::is_none")]
620            limit: Option<u64>,
621        }
622
623        client.get_query("/billing/meters", Params { limit: Some(100) })
624    }
625}
626
627#[derive(Deserialize)]
628struct StripeMeterEvent {
629    identifier: String,
630}
631
632impl StripeMeterEvent {
633    pub async fn create(
634        client: &stripe::Client,
635        params: StripeCreateMeterEventParams<'_>,
636    ) -> Result<Self, stripe::StripeError> {
637        let identifier = params.identifier;
638        match client.post_form("/billing/meter_events", params).await {
639            Ok(event) => Ok(event),
640            Err(stripe::StripeError::Stripe(error)) => {
641                if error.http_status == 400
642                    && error
643                        .message
644                        .as_ref()
645                        .map_or(false, |message| message.contains(identifier))
646                {
647                    Ok(Self {
648                        identifier: identifier.to_string(),
649                    })
650                } else {
651                    Err(stripe::StripeError::Stripe(error))
652                }
653            }
654            Err(error) => Err(error),
655        }
656    }
657}
658
659#[derive(Serialize)]
660struct StripeCreateMeterEventParams<'a> {
661    identifier: &'a str,
662    event_name: &'a str,
663    payload: StripeCreateMeterEventPayload<'a>,
664    timestamp: Option<i64>,
665}
666
667#[derive(Serialize)]
668struct StripeCreateMeterEventPayload<'a> {
669    value: u64,
670    stripe_customer_id: &'a stripe::CustomerId,
671}
672
673fn subscription_contains_price(
674    subscription: &stripe::Subscription,
675    price_id: &stripe::PriceId,
676) -> bool {
677    subscription.items.data.iter().any(|item| {
678        item.price
679            .as_ref()
680            .map_or(false, |price| price.id == *price_id)
681    })
682}