stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
  4use crate::{Cents, Result};
  5use anyhow::{Context as _, anyhow};
  6use chrono::{Datelike, Utc};
  7use collections::HashMap;
  8use serde::{Deserialize, Serialize};
  9use stripe::PriceId;
 10use tokio::sync::RwLock;
 11use uuid::Uuid;
 12
 13pub struct StripeBilling {
 14    state: RwLock<StripeBillingState>,
 15    client: Arc<stripe::Client>,
 16}
 17
 18#[derive(Default)]
 19struct StripeBillingState {
 20    meters_by_event_name: HashMap<String, StripeMeter>,
 21    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 22    prices_by_lookup_key: HashMap<String, stripe::Price>,
 23}
 24
 25pub struct StripeModelTokenPrices {
 26    input_tokens_price: StripeBillingPrice,
 27    input_cache_creation_tokens_price: StripeBillingPrice,
 28    input_cache_read_tokens_price: StripeBillingPrice,
 29    output_tokens_price: StripeBillingPrice,
 30}
 31
 32struct StripeBillingPrice {
 33    id: stripe::PriceId,
 34    meter_event_name: String,
 35}
 36
 37impl StripeBilling {
 38    pub fn new(client: Arc<stripe::Client>) -> Self {
 39        Self {
 40            client,
 41            state: RwLock::default(),
 42        }
 43    }
 44
 45    pub async fn initialize(&self) -> Result<()> {
 46        log::info!("StripeBilling: initializing");
 47
 48        let mut state = self.state.write().await;
 49
 50        let (meters, prices) = futures::try_join!(
 51            StripeMeter::list(&self.client),
 52            stripe::Price::list(
 53                &self.client,
 54                &stripe::ListPrices {
 55                    limit: Some(100),
 56                    ..Default::default()
 57                }
 58            )
 59        )?;
 60
 61        for meter in meters.data {
 62            state
 63                .meters_by_event_name
 64                .insert(meter.event_name.clone(), meter);
 65        }
 66
 67        for price in prices.data {
 68            if let Some(lookup_key) = price.lookup_key.clone() {
 69                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 70            }
 71
 72            if let Some(recurring) = price.recurring {
 73                if let Some(meter) = recurring.meter {
 74                    state.price_ids_by_meter_id.insert(meter, price.id);
 75                }
 76            }
 77        }
 78
 79        log::info!("StripeBilling: initialized");
 80
 81        Ok(())
 82    }
 83
 84    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
 85        self.state
 86            .read()
 87            .await
 88            .prices_by_lookup_key
 89            .get(lookup_key)
 90            .cloned()
 91            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 92    }
 93
 94    pub async fn register_model_for_token_based_usage(
 95        &self,
 96        model: &llm::db::model::Model,
 97    ) -> Result<StripeModelTokenPrices> {
 98        let input_tokens_price = self
 99            .get_or_insert_token_price(
100                &format!("model_{}/input_tokens", model.id),
101                &format!("{} (Input Tokens)", model.name),
102                Cents::new(model.price_per_million_input_tokens as u32),
103            )
104            .await?;
105        let input_cache_creation_tokens_price = self
106            .get_or_insert_token_price(
107                &format!("model_{}/input_cache_creation_tokens", model.id),
108                &format!("{} (Input Cache Creation Tokens)", model.name),
109                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
110            )
111            .await?;
112        let input_cache_read_tokens_price = self
113            .get_or_insert_token_price(
114                &format!("model_{}/input_cache_read_tokens", model.id),
115                &format!("{} (Input Cache Read Tokens)", model.name),
116                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
117            )
118            .await?;
119        let output_tokens_price = self
120            .get_or_insert_token_price(
121                &format!("model_{}/output_tokens", model.id),
122                &format!("{} (Output Tokens)", model.name),
123                Cents::new(model.price_per_million_output_tokens as u32),
124            )
125            .await?;
126        Ok(StripeModelTokenPrices {
127            input_tokens_price,
128            input_cache_creation_tokens_price,
129            input_cache_read_tokens_price,
130            output_tokens_price,
131        })
132    }
133
134    async fn get_or_insert_token_price(
135        &self,
136        meter_event_name: &str,
137        price_description: &str,
138        price_per_million_tokens: Cents,
139    ) -> Result<StripeBillingPrice> {
140        // Fast code path when the meter and the price already exist.
141        {
142            let state = self.state.read().await;
143            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
144                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
145                    return Ok(StripeBillingPrice {
146                        id: price_id.clone(),
147                        meter_event_name: meter_event_name.to_string(),
148                    });
149                }
150            }
151        }
152
153        let mut state = self.state.write().await;
154        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
155            meter.clone()
156        } else {
157            let meter = StripeMeter::create(
158                &self.client,
159                StripeCreateMeterParams {
160                    default_aggregation: DefaultAggregation { formula: "sum" },
161                    display_name: price_description.to_string(),
162                    event_name: meter_event_name,
163                },
164            )
165            .await?;
166            state
167                .meters_by_event_name
168                .insert(meter_event_name.to_string(), meter.clone());
169            meter
170        };
171
172        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
173            price_id.clone()
174        } else {
175            let price = stripe::Price::create(
176                &self.client,
177                stripe::CreatePrice {
178                    active: Some(true),
179                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
180                    currency: stripe::Currency::USD,
181                    currency_options: None,
182                    custom_unit_amount: None,
183                    expand: &[],
184                    lookup_key: None,
185                    metadata: None,
186                    nickname: None,
187                    product: None,
188                    product_data: Some(stripe::CreatePriceProductData {
189                        id: None,
190                        active: Some(true),
191                        metadata: None,
192                        name: price_description.to_string(),
193                        statement_descriptor: None,
194                        tax_code: None,
195                        unit_label: None,
196                    }),
197                    recurring: Some(stripe::CreatePriceRecurring {
198                        aggregate_usage: None,
199                        interval: stripe::CreatePriceRecurringInterval::Month,
200                        interval_count: None,
201                        trial_period_days: None,
202                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
203                        meter: Some(meter.id.clone()),
204                    }),
205                    tax_behavior: None,
206                    tiers: None,
207                    tiers_mode: None,
208                    transfer_lookup_key: None,
209                    transform_quantity: None,
210                    unit_amount: None,
211                    unit_amount_decimal: Some(&format!(
212                        "{:.12}",
213                        price_per_million_tokens.0 as f64 / 1_000_000f64
214                    )),
215                },
216            )
217            .await?;
218            state
219                .price_ids_by_meter_id
220                .insert(meter.id, price.id.clone());
221            price.id
222        };
223
224        Ok(StripeBillingPrice {
225            id: price_id,
226            meter_event_name: meter_event_name.to_string(),
227        })
228    }
229
230    pub async fn subscribe_to_price(
231        &self,
232        subscription_id: &stripe::SubscriptionId,
233        price_id: &stripe::PriceId,
234    ) -> Result<()> {
235        let subscription =
236            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
237
238        if subscription_contains_price(&subscription, price_id) {
239            return Ok(());
240        }
241
242        stripe::Subscription::update(
243            &self.client,
244            subscription_id,
245            stripe::UpdateSubscription {
246                items: Some(vec![stripe::UpdateSubscriptionItems {
247                    price: Some(price_id.to_string()),
248                    ..Default::default()
249                }]),
250                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
251                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
252                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
253                    },
254                }),
255                ..Default::default()
256            },
257        )
258        .await?;
259
260        Ok(())
261    }
262
263    pub async fn subscribe_to_model(
264        &self,
265        subscription_id: &stripe::SubscriptionId,
266        model: &StripeModelTokenPrices,
267    ) -> Result<()> {
268        let subscription =
269            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
270
271        let mut items = Vec::new();
272
273        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
274            items.push(stripe::UpdateSubscriptionItems {
275                price: Some(model.input_tokens_price.id.to_string()),
276                ..Default::default()
277            });
278        }
279
280        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
281        {
282            items.push(stripe::UpdateSubscriptionItems {
283                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
284                ..Default::default()
285            });
286        }
287
288        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
289            items.push(stripe::UpdateSubscriptionItems {
290                price: Some(model.input_cache_read_tokens_price.id.to_string()),
291                ..Default::default()
292            });
293        }
294
295        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
296            items.push(stripe::UpdateSubscriptionItems {
297                price: Some(model.output_tokens_price.id.to_string()),
298                ..Default::default()
299            });
300        }
301
302        if !items.is_empty() {
303            items.extend(subscription.items.data.iter().map(|item| {
304                stripe::UpdateSubscriptionItems {
305                    id: Some(item.id.to_string()),
306                    ..Default::default()
307                }
308            }));
309
310            stripe::Subscription::update(
311                &self.client,
312                subscription_id,
313                stripe::UpdateSubscription {
314                    items: Some(items),
315                    ..Default::default()
316                },
317            )
318            .await?;
319        }
320
321        Ok(())
322    }
323
324    pub async fn bill_model_token_usage(
325        &self,
326        customer_id: &stripe::CustomerId,
327        model: &StripeModelTokenPrices,
328        event: &llm::db::billing_event::Model,
329    ) -> Result<()> {
330        let timestamp = Utc::now().timestamp();
331
332        if event.input_tokens > 0 {
333            StripeMeterEvent::create(
334                &self.client,
335                StripeCreateMeterEventParams {
336                    identifier: &format!("input_tokens/{}", event.idempotency_key),
337                    event_name: &model.input_tokens_price.meter_event_name,
338                    payload: StripeCreateMeterEventPayload {
339                        value: event.input_tokens as u64,
340                        stripe_customer_id: customer_id,
341                    },
342                    timestamp: Some(timestamp),
343                },
344            )
345            .await?;
346        }
347
348        if event.input_cache_creation_tokens > 0 {
349            StripeMeterEvent::create(
350                &self.client,
351                StripeCreateMeterEventParams {
352                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
353                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
354                    payload: StripeCreateMeterEventPayload {
355                        value: event.input_cache_creation_tokens as u64,
356                        stripe_customer_id: customer_id,
357                    },
358                    timestamp: Some(timestamp),
359                },
360            )
361            .await?;
362        }
363
364        if event.input_cache_read_tokens > 0 {
365            StripeMeterEvent::create(
366                &self.client,
367                StripeCreateMeterEventParams {
368                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
369                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
370                    payload: StripeCreateMeterEventPayload {
371                        value: event.input_cache_read_tokens as u64,
372                        stripe_customer_id: customer_id,
373                    },
374                    timestamp: Some(timestamp),
375                },
376            )
377            .await?;
378        }
379
380        if event.output_tokens > 0 {
381            StripeMeterEvent::create(
382                &self.client,
383                StripeCreateMeterEventParams {
384                    identifier: &format!("output_tokens/{}", event.idempotency_key),
385                    event_name: &model.output_tokens_price.meter_event_name,
386                    payload: StripeCreateMeterEventPayload {
387                        value: event.output_tokens as u64,
388                        stripe_customer_id: customer_id,
389                    },
390                    timestamp: Some(timestamp),
391                },
392            )
393            .await?;
394        }
395
396        Ok(())
397    }
398
399    pub async fn bill_model_request_usage(
400        &self,
401        customer_id: &stripe::CustomerId,
402        event_name: &str,
403        requests: i32,
404    ) -> Result<()> {
405        let timestamp = Utc::now().timestamp();
406        let idempotency_key = Uuid::new_v4();
407
408        StripeMeterEvent::create(
409            &self.client,
410            StripeCreateMeterEventParams {
411                identifier: &format!("model_requests/{}", idempotency_key),
412                event_name,
413                payload: StripeCreateMeterEventPayload {
414                    value: requests as u64,
415                    stripe_customer_id: customer_id,
416                },
417                timestamp: Some(timestamp),
418            },
419        )
420        .await?;
421
422        Ok(())
423    }
424
425    pub async fn checkout(
426        &self,
427        customer_id: stripe::CustomerId,
428        github_login: &str,
429        model: &StripeModelTokenPrices,
430        success_url: &str,
431    ) -> Result<String> {
432        let first_of_next_month = Utc::now()
433            .checked_add_months(chrono::Months::new(1))
434            .unwrap()
435            .with_day(1)
436            .unwrap();
437
438        let mut params = stripe::CreateCheckoutSession::new();
439        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
440        params.customer = Some(customer_id);
441        params.client_reference_id = Some(github_login);
442        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
443            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
444            ..Default::default()
445        });
446        params.line_items = Some(
447            [
448                &model.input_tokens_price.id,
449                &model.input_cache_creation_tokens_price.id,
450                &model.input_cache_read_tokens_price.id,
451                &model.output_tokens_price.id,
452            ]
453            .into_iter()
454            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
455                price: Some(price_id.to_string()),
456                ..Default::default()
457            })
458            .collect(),
459        );
460        params.success_url = Some(success_url);
461
462        let session = stripe::CheckoutSession::create(&self.client, params).await?;
463        Ok(session.url.context("no checkout session URL")?)
464    }
465
466    pub async fn checkout_with_price(
467        &self,
468        price_id: PriceId,
469        customer_id: stripe::CustomerId,
470        github_login: &str,
471        success_url: &str,
472    ) -> Result<String> {
473        let mut params = stripe::CreateCheckoutSession::new();
474        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
475        params.customer = Some(customer_id);
476        params.client_reference_id = Some(github_login);
477        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
478            price: Some(price_id.to_string()),
479            quantity: Some(1),
480            ..Default::default()
481        }]);
482        params.success_url = Some(success_url);
483
484        let session = stripe::CheckoutSession::create(&self.client, params).await?;
485        Ok(session.url.context("no checkout session URL")?)
486    }
487
488    pub async fn checkout_with_zed_pro_trial(
489        &self,
490        zed_pro_price_id: PriceId,
491        customer_id: stripe::CustomerId,
492        github_login: &str,
493        feature_flags: Vec<String>,
494        success_url: &str,
495    ) -> Result<String> {
496        let eligible_for_extended_trial = feature_flags
497            .iter()
498            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
499
500        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
501
502        let mut subscription_metadata = std::collections::HashMap::new();
503        if eligible_for_extended_trial {
504            subscription_metadata.insert(
505                "promo_feature_flag".to_string(),
506                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
507            );
508        }
509
510        let mut params = stripe::CreateCheckoutSession::new();
511        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
512            trial_period_days: Some(trial_period_days),
513            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
514                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
515                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
516                }
517            }),
518            metadata: if !subscription_metadata.is_empty() {
519                Some(subscription_metadata)
520            } else {
521                None
522            },
523            ..Default::default()
524        });
525        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
526        params.payment_method_collection =
527            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
528        params.customer = Some(customer_id);
529        params.client_reference_id = Some(github_login);
530        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
531            price: Some(zed_pro_price_id.to_string()),
532            quantity: Some(1),
533            ..Default::default()
534        }]);
535        params.success_url = Some(success_url);
536
537        let session = stripe::CheckoutSession::create(&self.client, params).await?;
538        Ok(session.url.context("no checkout session URL")?)
539    }
540}
541
/// Serialized as the `default_aggregation` field of `StripeCreateMeterParams`.
#[derive(Serialize)]
struct DefaultAggregation {
    /// Name of the aggregation formula; this file only ever passes `"sum"`.
    formula: &'static str,
}
546
/// Form body sent by `StripeMeter::create` to `POST /billing/meters`.
///
/// Field names are serialized as-is, so they must match Stripe's parameter names.
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
    /// How the meter aggregates the values of its events.
    default_aggregation: DefaultAggregation,
    /// Human-readable name of the meter.
    display_name: String,
    /// Event name that meter events report under to be counted by this meter.
    event_name: &'a str,
}
553
/// The subset of a Stripe billing meter that this module reads.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// Stripe meter ID; used as the key of `StripeBillingState::price_ids_by_meter_id`.
    id: String,
    /// Event name the meter aggregates; used as the key of
    /// `StripeBillingState::meters_by_event_name`.
    event_name: String,
}
559
560impl StripeMeter {
561    pub fn create(
562        client: &stripe::Client,
563        params: StripeCreateMeterParams,
564    ) -> stripe::Response<Self> {
565        client.post_form("/billing/meters", params)
566    }
567
568    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
569        #[derive(Serialize)]
570        struct Params {
571            #[serde(skip_serializing_if = "Option::is_none")]
572            limit: Option<u64>,
573        }
574
575        client.get_query("/billing/meters", Params { limit: Some(100) })
576    }
577}
578
/// The part of a created meter event that this module keeps.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier the meter event was submitted with.
    identifier: String,
}
583
584impl StripeMeterEvent {
585    pub async fn create(
586        client: &stripe::Client,
587        params: StripeCreateMeterEventParams<'_>,
588    ) -> Result<Self, stripe::StripeError> {
589        let identifier = params.identifier;
590        match client.post_form("/billing/meter_events", params).await {
591            Ok(event) => Ok(event),
592            Err(stripe::StripeError::Stripe(error)) => {
593                if error.http_status == 400
594                    && error
595                        .message
596                        .as_ref()
597                        .map_or(false, |message| message.contains(identifier))
598                {
599                    Ok(Self {
600                        identifier: identifier.to_string(),
601                    })
602                } else {
603                    Err(stripe::StripeError::Stripe(error))
604                }
605            }
606            Err(error) => Err(error),
607        }
608    }
609}
610
/// Form body sent by `StripeMeterEvent::create` to `POST /billing/meter_events`.
///
/// Field names are serialized as-is, so they must match Stripe's parameter names.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier of this event; the callers in this file derive it from an idempotency key.
    identifier: &'a str,
    /// Event name of the meter the usage is reported to.
    event_name: &'a str,
    /// Usage value and the customer it is attributed to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time; the callers in this file pass `Utc::now().timestamp()` (Unix seconds).
    timestamp: Option<i64>,
}
618
/// The `payload` of a meter event.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Usage quantity; the callers in this file pass a token count or a request count.
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
624
625fn subscription_contains_price(
626    subscription: &stripe::Subscription,
627    price_id: &stripe::PriceId,
628) -> bool {
629    subscription.items.data.iter().any(|item| {
630        item.price
631            .as_ref()
632            .map_or(false, |price| price.id == *price_id)
633    })
634}