stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
  4use crate::{Cents, Result};
  5use anyhow::{Context as _, anyhow};
  6use chrono::{Datelike, Utc};
  7use collections::HashMap;
  8use serde::{Deserialize, Serialize};
  9use stripe::PriceId;
 10use tokio::sync::RwLock;
 11use uuid::Uuid;
 12
 13pub struct StripeBilling {
 14    state: RwLock<StripeBillingState>,
 15    client: Arc<stripe::Client>,
 16}
 17
 18#[derive(Default)]
 19struct StripeBillingState {
 20    meters_by_event_name: HashMap<String, StripeMeter>,
 21    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 22    prices_by_lookup_key: HashMap<String, stripe::Price>,
 23}
 24
 25pub struct StripeModelTokenPrices {
 26    input_tokens_price: StripeBillingPrice,
 27    input_cache_creation_tokens_price: StripeBillingPrice,
 28    input_cache_read_tokens_price: StripeBillingPrice,
 29    output_tokens_price: StripeBillingPrice,
 30}
 31
 32struct StripeBillingPrice {
 33    id: stripe::PriceId,
 34    meter_event_name: String,
 35}
 36
 37impl StripeBilling {
 38    pub fn new(client: Arc<stripe::Client>) -> Self {
 39        Self {
 40            client,
 41            state: RwLock::default(),
 42        }
 43    }
 44
 45    pub async fn initialize(&self) -> Result<()> {
 46        log::info!("StripeBilling: initializing");
 47
 48        let mut state = self.state.write().await;
 49
 50        let (meters, prices) = futures::try_join!(
 51            StripeMeter::list(&self.client),
 52            stripe::Price::list(
 53                &self.client,
 54                &stripe::ListPrices {
 55                    limit: Some(100),
 56                    ..Default::default()
 57                }
 58            )
 59        )?;
 60
 61        for meter in meters.data {
 62            state
 63                .meters_by_event_name
 64                .insert(meter.event_name.clone(), meter);
 65        }
 66
 67        for price in prices.data {
 68            if let Some(lookup_key) = price.lookup_key.clone() {
 69                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 70            }
 71
 72            if let Some(recurring) = price.recurring {
 73                if let Some(meter) = recurring.meter {
 74                    state.price_ids_by_meter_id.insert(meter, price.id);
 75                }
 76            }
 77        }
 78
 79        log::info!("StripeBilling: initialized");
 80
 81        Ok(())
 82    }
 83
 84    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 85        self.find_price_id_by_lookup_key("zed-pro").await
 86    }
 87
 88    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 89        self.find_price_id_by_lookup_key("zed-free").await
 90    }
 91
 92    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 93        self.state
 94            .read()
 95            .await
 96            .prices_by_lookup_key
 97            .get(lookup_key)
 98            .map(|price| price.id.clone())
 99            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
100    }
101
102    pub async fn register_model_for_token_based_usage(
103        &self,
104        model: &llm::db::model::Model,
105    ) -> Result<StripeModelTokenPrices> {
106        let input_tokens_price = self
107            .get_or_insert_token_price(
108                &format!("model_{}/input_tokens", model.id),
109                &format!("{} (Input Tokens)", model.name),
110                Cents::new(model.price_per_million_input_tokens as u32),
111            )
112            .await?;
113        let input_cache_creation_tokens_price = self
114            .get_or_insert_token_price(
115                &format!("model_{}/input_cache_creation_tokens", model.id),
116                &format!("{} (Input Cache Creation Tokens)", model.name),
117                Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
118            )
119            .await?;
120        let input_cache_read_tokens_price = self
121            .get_or_insert_token_price(
122                &format!("model_{}/input_cache_read_tokens", model.id),
123                &format!("{} (Input Cache Read Tokens)", model.name),
124                Cents::new(model.price_per_million_cache_read_input_tokens as u32),
125            )
126            .await?;
127        let output_tokens_price = self
128            .get_or_insert_token_price(
129                &format!("model_{}/output_tokens", model.id),
130                &format!("{} (Output Tokens)", model.name),
131                Cents::new(model.price_per_million_output_tokens as u32),
132            )
133            .await?;
134        Ok(StripeModelTokenPrices {
135            input_tokens_price,
136            input_cache_creation_tokens_price,
137            input_cache_read_tokens_price,
138            output_tokens_price,
139        })
140    }
141
142    async fn get_or_insert_token_price(
143        &self,
144        meter_event_name: &str,
145        price_description: &str,
146        price_per_million_tokens: Cents,
147    ) -> Result<StripeBillingPrice> {
148        // Fast code path when the meter and the price already exist.
149        {
150            let state = self.state.read().await;
151            if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
152                if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
153                    return Ok(StripeBillingPrice {
154                        id: price_id.clone(),
155                        meter_event_name: meter_event_name.to_string(),
156                    });
157                }
158            }
159        }
160
161        let mut state = self.state.write().await;
162        let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
163            meter.clone()
164        } else {
165            let meter = StripeMeter::create(
166                &self.client,
167                StripeCreateMeterParams {
168                    default_aggregation: DefaultAggregation { formula: "sum" },
169                    display_name: price_description.to_string(),
170                    event_name: meter_event_name,
171                },
172            )
173            .await?;
174            state
175                .meters_by_event_name
176                .insert(meter_event_name.to_string(), meter.clone());
177            meter
178        };
179
180        let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
181            price_id.clone()
182        } else {
183            let price = stripe::Price::create(
184                &self.client,
185                stripe::CreatePrice {
186                    active: Some(true),
187                    billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
188                    currency: stripe::Currency::USD,
189                    currency_options: None,
190                    custom_unit_amount: None,
191                    expand: &[],
192                    lookup_key: None,
193                    metadata: None,
194                    nickname: None,
195                    product: None,
196                    product_data: Some(stripe::CreatePriceProductData {
197                        id: None,
198                        active: Some(true),
199                        metadata: None,
200                        name: price_description.to_string(),
201                        statement_descriptor: None,
202                        tax_code: None,
203                        unit_label: None,
204                    }),
205                    recurring: Some(stripe::CreatePriceRecurring {
206                        aggregate_usage: None,
207                        interval: stripe::CreatePriceRecurringInterval::Month,
208                        interval_count: None,
209                        trial_period_days: None,
210                        usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
211                        meter: Some(meter.id.clone()),
212                    }),
213                    tax_behavior: None,
214                    tiers: None,
215                    tiers_mode: None,
216                    transfer_lookup_key: None,
217                    transform_quantity: None,
218                    unit_amount: None,
219                    unit_amount_decimal: Some(&format!(
220                        "{:.12}",
221                        price_per_million_tokens.0 as f64 / 1_000_000f64
222                    )),
223                },
224            )
225            .await?;
226            state
227                .price_ids_by_meter_id
228                .insert(meter.id, price.id.clone());
229            price.id
230        };
231
232        Ok(StripeBillingPrice {
233            id: price_id,
234            meter_event_name: meter_event_name.to_string(),
235        })
236    }
237
238    pub async fn subscribe_to_price(
239        &self,
240        subscription_id: &stripe::SubscriptionId,
241        price_id: &stripe::PriceId,
242    ) -> Result<()> {
243        let subscription =
244            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
245
246        if subscription_contains_price(&subscription, price_id) {
247            return Ok(());
248        }
249
250        stripe::Subscription::update(
251            &self.client,
252            subscription_id,
253            stripe::UpdateSubscription {
254                items: Some(vec![stripe::UpdateSubscriptionItems {
255                    price: Some(price_id.to_string()),
256                    ..Default::default()
257                }]),
258                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
259                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
260                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
261                    },
262                }),
263                ..Default::default()
264            },
265        )
266        .await?;
267
268        Ok(())
269    }
270
271    pub async fn subscribe_to_model(
272        &self,
273        subscription_id: &stripe::SubscriptionId,
274        model: &StripeModelTokenPrices,
275    ) -> Result<()> {
276        let subscription =
277            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
278
279        let mut items = Vec::new();
280
281        if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
282            items.push(stripe::UpdateSubscriptionItems {
283                price: Some(model.input_tokens_price.id.to_string()),
284                ..Default::default()
285            });
286        }
287
288        if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
289        {
290            items.push(stripe::UpdateSubscriptionItems {
291                price: Some(model.input_cache_creation_tokens_price.id.to_string()),
292                ..Default::default()
293            });
294        }
295
296        if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
297            items.push(stripe::UpdateSubscriptionItems {
298                price: Some(model.input_cache_read_tokens_price.id.to_string()),
299                ..Default::default()
300            });
301        }
302
303        if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
304            items.push(stripe::UpdateSubscriptionItems {
305                price: Some(model.output_tokens_price.id.to_string()),
306                ..Default::default()
307            });
308        }
309
310        if !items.is_empty() {
311            items.extend(subscription.items.data.iter().map(|item| {
312                stripe::UpdateSubscriptionItems {
313                    id: Some(item.id.to_string()),
314                    ..Default::default()
315                }
316            }));
317
318            stripe::Subscription::update(
319                &self.client,
320                subscription_id,
321                stripe::UpdateSubscription {
322                    items: Some(items),
323                    ..Default::default()
324                },
325            )
326            .await?;
327        }
328
329        Ok(())
330    }
331
332    pub async fn bill_model_token_usage(
333        &self,
334        customer_id: &stripe::CustomerId,
335        model: &StripeModelTokenPrices,
336        event: &llm::db::billing_event::Model,
337    ) -> Result<()> {
338        let timestamp = Utc::now().timestamp();
339
340        if event.input_tokens > 0 {
341            StripeMeterEvent::create(
342                &self.client,
343                StripeCreateMeterEventParams {
344                    identifier: &format!("input_tokens/{}", event.idempotency_key),
345                    event_name: &model.input_tokens_price.meter_event_name,
346                    payload: StripeCreateMeterEventPayload {
347                        value: event.input_tokens as u64,
348                        stripe_customer_id: customer_id,
349                    },
350                    timestamp: Some(timestamp),
351                },
352            )
353            .await?;
354        }
355
356        if event.input_cache_creation_tokens > 0 {
357            StripeMeterEvent::create(
358                &self.client,
359                StripeCreateMeterEventParams {
360                    identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
361                    event_name: &model.input_cache_creation_tokens_price.meter_event_name,
362                    payload: StripeCreateMeterEventPayload {
363                        value: event.input_cache_creation_tokens as u64,
364                        stripe_customer_id: customer_id,
365                    },
366                    timestamp: Some(timestamp),
367                },
368            )
369            .await?;
370        }
371
372        if event.input_cache_read_tokens > 0 {
373            StripeMeterEvent::create(
374                &self.client,
375                StripeCreateMeterEventParams {
376                    identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
377                    event_name: &model.input_cache_read_tokens_price.meter_event_name,
378                    payload: StripeCreateMeterEventPayload {
379                        value: event.input_cache_read_tokens as u64,
380                        stripe_customer_id: customer_id,
381                    },
382                    timestamp: Some(timestamp),
383                },
384            )
385            .await?;
386        }
387
388        if event.output_tokens > 0 {
389            StripeMeterEvent::create(
390                &self.client,
391                StripeCreateMeterEventParams {
392                    identifier: &format!("output_tokens/{}", event.idempotency_key),
393                    event_name: &model.output_tokens_price.meter_event_name,
394                    payload: StripeCreateMeterEventPayload {
395                        value: event.output_tokens as u64,
396                        stripe_customer_id: customer_id,
397                    },
398                    timestamp: Some(timestamp),
399                },
400            )
401            .await?;
402        }
403
404        Ok(())
405    }
406
407    pub async fn bill_model_request_usage(
408        &self,
409        customer_id: &stripe::CustomerId,
410        event_name: &str,
411        requests: i32,
412    ) -> Result<()> {
413        let timestamp = Utc::now().timestamp();
414        let idempotency_key = Uuid::new_v4();
415
416        StripeMeterEvent::create(
417            &self.client,
418            StripeCreateMeterEventParams {
419                identifier: &format!("model_requests/{}", idempotency_key),
420                event_name,
421                payload: StripeCreateMeterEventPayload {
422                    value: requests as u64,
423                    stripe_customer_id: customer_id,
424                },
425                timestamp: Some(timestamp),
426            },
427        )
428        .await?;
429
430        Ok(())
431    }
432
433    pub async fn checkout(
434        &self,
435        customer_id: stripe::CustomerId,
436        github_login: &str,
437        model: &StripeModelTokenPrices,
438        success_url: &str,
439    ) -> Result<String> {
440        let first_of_next_month = Utc::now()
441            .checked_add_months(chrono::Months::new(1))
442            .unwrap()
443            .with_day(1)
444            .unwrap();
445
446        let mut params = stripe::CreateCheckoutSession::new();
447        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
448        params.customer = Some(customer_id);
449        params.client_reference_id = Some(github_login);
450        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
451            billing_cycle_anchor: Some(first_of_next_month.timestamp()),
452            ..Default::default()
453        });
454        params.line_items = Some(
455            [
456                &model.input_tokens_price.id,
457                &model.input_cache_creation_tokens_price.id,
458                &model.input_cache_read_tokens_price.id,
459                &model.output_tokens_price.id,
460            ]
461            .into_iter()
462            .map(|price_id| stripe::CreateCheckoutSessionLineItems {
463                price: Some(price_id.to_string()),
464                ..Default::default()
465            })
466            .collect(),
467        );
468        params.success_url = Some(success_url);
469
470        let session = stripe::CheckoutSession::create(&self.client, params).await?;
471        Ok(session.url.context("no checkout session URL")?)
472    }
473
474    pub async fn checkout_with_zed_pro(
475        &self,
476        customer_id: stripe::CustomerId,
477        github_login: &str,
478        success_url: &str,
479    ) -> Result<String> {
480        let zed_pro_price_id = self.zed_pro_price_id().await?;
481
482        let mut params = stripe::CreateCheckoutSession::new();
483        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
484        params.customer = Some(customer_id);
485        params.client_reference_id = Some(github_login);
486        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
487            price: Some(zed_pro_price_id.to_string()),
488            quantity: Some(1),
489            ..Default::default()
490        }]);
491        params.success_url = Some(success_url);
492
493        let session = stripe::CheckoutSession::create(&self.client, params).await?;
494        Ok(session.url.context("no checkout session URL")?)
495    }
496
497    pub async fn checkout_with_zed_pro_trial(
498        &self,
499        customer_id: stripe::CustomerId,
500        github_login: &str,
501        feature_flags: Vec<String>,
502        success_url: &str,
503    ) -> Result<String> {
504        let zed_pro_price_id = self.zed_pro_price_id().await?;
505
506        let eligible_for_extended_trial = feature_flags
507            .iter()
508            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
509
510        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
511
512        let mut subscription_metadata = std::collections::HashMap::new();
513        if eligible_for_extended_trial {
514            subscription_metadata.insert(
515                "promo_feature_flag".to_string(),
516                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
517            );
518        }
519
520        let mut params = stripe::CreateCheckoutSession::new();
521        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
522            trial_period_days: Some(trial_period_days),
523            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
524                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
525                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
526                }
527            }),
528            metadata: if !subscription_metadata.is_empty() {
529                Some(subscription_metadata)
530            } else {
531                None
532            },
533            ..Default::default()
534        });
535        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
536        params.payment_method_collection =
537            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
538        params.customer = Some(customer_id);
539        params.client_reference_id = Some(github_login);
540        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
541            price: Some(zed_pro_price_id.to_string()),
542            quantity: Some(1),
543            ..Default::default()
544        }]);
545        params.success_url = Some(success_url);
546
547        let session = stripe::CheckoutSession::create(&self.client, params).await?;
548        Ok(session.url.context("no checkout session URL")?)
549    }
550}
551
552#[derive(Serialize)]
553struct DefaultAggregation {
554    formula: &'static str,
555}
556
557#[derive(Serialize)]
558struct StripeCreateMeterParams<'a> {
559    default_aggregation: DefaultAggregation,
560    display_name: String,
561    event_name: &'a str,
562}
563
564#[derive(Clone, Deserialize)]
565struct StripeMeter {
566    id: String,
567    event_name: String,
568}
569
570impl StripeMeter {
571    pub fn create(
572        client: &stripe::Client,
573        params: StripeCreateMeterParams,
574    ) -> stripe::Response<Self> {
575        client.post_form("/billing/meters", params)
576    }
577
578    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
579        #[derive(Serialize)]
580        struct Params {
581            #[serde(skip_serializing_if = "Option::is_none")]
582            limit: Option<u64>,
583        }
584
585        client.get_query("/billing/meters", Params { limit: Some(100) })
586    }
587}
588
589#[derive(Deserialize)]
590struct StripeMeterEvent {
591    identifier: String,
592}
593
594impl StripeMeterEvent {
595    pub async fn create(
596        client: &stripe::Client,
597        params: StripeCreateMeterEventParams<'_>,
598    ) -> Result<Self, stripe::StripeError> {
599        let identifier = params.identifier;
600        match client.post_form("/billing/meter_events", params).await {
601            Ok(event) => Ok(event),
602            Err(stripe::StripeError::Stripe(error)) => {
603                if error.http_status == 400
604                    && error
605                        .message
606                        .as_ref()
607                        .map_or(false, |message| message.contains(identifier))
608                {
609                    Ok(Self {
610                        identifier: identifier.to_string(),
611                    })
612                } else {
613                    Err(stripe::StripeError::Stripe(error))
614                }
615            }
616            Err(error) => Err(error),
617        }
618    }
619}
620
621#[derive(Serialize)]
622struct StripeCreateMeterEventParams<'a> {
623    identifier: &'a str,
624    event_name: &'a str,
625    payload: StripeCreateMeterEventPayload<'a>,
626    timestamp: Option<i64>,
627}
628
629#[derive(Serialize)]
630struct StripeCreateMeterEventPayload<'a> {
631    value: u64,
632    stripe_customer_id: &'a stripe::CustomerId,
633}
634
635fn subscription_contains_price(
636    subscription: &stripe::Subscription,
637    price_id: &stripe::PriceId,
638) -> bool {
639    subscription.items.data.iter().any(|item| {
640        item.price
641            .as_ref()
642            .map_or(false, |price| price.id == *price_id)
643    })
644}