1use std::sync::Arc;
2
3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
4use crate::{Cents, Result};
5use anyhow::{Context as _, anyhow};
6use chrono::{Datelike, Utc};
7use collections::HashMap;
8use serde::{Deserialize, Serialize};
9use stripe::PriceId;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13pub struct StripeBilling {
14 state: RwLock<StripeBillingState>,
15 client: Arc<stripe::Client>,
16}
17
18#[derive(Default)]
19struct StripeBillingState {
20 meters_by_event_name: HashMap<String, StripeMeter>,
21 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
22 prices_by_lookup_key: HashMap<String, stripe::Price>,
23}
24
25pub struct StripeModelTokenPrices {
26 input_tokens_price: StripeBillingPrice,
27 input_cache_creation_tokens_price: StripeBillingPrice,
28 input_cache_read_tokens_price: StripeBillingPrice,
29 output_tokens_price: StripeBillingPrice,
30}
31
32struct StripeBillingPrice {
33 id: stripe::PriceId,
34 meter_event_name: String,
35}
36
37impl StripeBilling {
38 pub fn new(client: Arc<stripe::Client>) -> Self {
39 Self {
40 client,
41 state: RwLock::default(),
42 }
43 }
44
45 pub async fn initialize(&self) -> Result<()> {
46 log::info!("StripeBilling: initializing");
47
48 let mut state = self.state.write().await;
49
50 let (meters, prices) = futures::try_join!(
51 StripeMeter::list(&self.client),
52 stripe::Price::list(
53 &self.client,
54 &stripe::ListPrices {
55 limit: Some(100),
56 ..Default::default()
57 }
58 )
59 )?;
60
61 for meter in meters.data {
62 state
63 .meters_by_event_name
64 .insert(meter.event_name.clone(), meter);
65 }
66
67 for price in prices.data {
68 if let Some(lookup_key) = price.lookup_key.clone() {
69 state.prices_by_lookup_key.insert(lookup_key, price.clone());
70 }
71
72 if let Some(recurring) = price.recurring {
73 if let Some(meter) = recurring.meter {
74 state.price_ids_by_meter_id.insert(meter, price.id);
75 }
76 }
77 }
78
79 log::info!("StripeBilling: initialized");
80
81 Ok(())
82 }
83
84 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
85 self.find_price_id_by_lookup_key("zed-pro").await
86 }
87
88 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
89 self.find_price_id_by_lookup_key("zed-free").await
90 }
91
92 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
93 self.state
94 .read()
95 .await
96 .prices_by_lookup_key
97 .get(lookup_key)
98 .map(|price| price.id.clone())
99 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
100 }
101
102 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
103 self.state
104 .read()
105 .await
106 .prices_by_lookup_key
107 .get(lookup_key)
108 .cloned()
109 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
110 }
111
112 pub async fn register_model_for_token_based_usage(
113 &self,
114 model: &llm::db::model::Model,
115 ) -> Result<StripeModelTokenPrices> {
116 let input_tokens_price = self
117 .get_or_insert_token_price(
118 &format!("model_{}/input_tokens", model.id),
119 &format!("{} (Input Tokens)", model.name),
120 Cents::new(model.price_per_million_input_tokens as u32),
121 )
122 .await?;
123 let input_cache_creation_tokens_price = self
124 .get_or_insert_token_price(
125 &format!("model_{}/input_cache_creation_tokens", model.id),
126 &format!("{} (Input Cache Creation Tokens)", model.name),
127 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
128 )
129 .await?;
130 let input_cache_read_tokens_price = self
131 .get_or_insert_token_price(
132 &format!("model_{}/input_cache_read_tokens", model.id),
133 &format!("{} (Input Cache Read Tokens)", model.name),
134 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
135 )
136 .await?;
137 let output_tokens_price = self
138 .get_or_insert_token_price(
139 &format!("model_{}/output_tokens", model.id),
140 &format!("{} (Output Tokens)", model.name),
141 Cents::new(model.price_per_million_output_tokens as u32),
142 )
143 .await?;
144 Ok(StripeModelTokenPrices {
145 input_tokens_price,
146 input_cache_creation_tokens_price,
147 input_cache_read_tokens_price,
148 output_tokens_price,
149 })
150 }
151
152 async fn get_or_insert_token_price(
153 &self,
154 meter_event_name: &str,
155 price_description: &str,
156 price_per_million_tokens: Cents,
157 ) -> Result<StripeBillingPrice> {
158 // Fast code path when the meter and the price already exist.
159 {
160 let state = self.state.read().await;
161 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
162 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
163 return Ok(StripeBillingPrice {
164 id: price_id.clone(),
165 meter_event_name: meter_event_name.to_string(),
166 });
167 }
168 }
169 }
170
171 let mut state = self.state.write().await;
172 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
173 meter.clone()
174 } else {
175 let meter = StripeMeter::create(
176 &self.client,
177 StripeCreateMeterParams {
178 default_aggregation: DefaultAggregation { formula: "sum" },
179 display_name: price_description.to_string(),
180 event_name: meter_event_name,
181 },
182 )
183 .await?;
184 state
185 .meters_by_event_name
186 .insert(meter_event_name.to_string(), meter.clone());
187 meter
188 };
189
190 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
191 price_id.clone()
192 } else {
193 let price = stripe::Price::create(
194 &self.client,
195 stripe::CreatePrice {
196 active: Some(true),
197 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
198 currency: stripe::Currency::USD,
199 currency_options: None,
200 custom_unit_amount: None,
201 expand: &[],
202 lookup_key: None,
203 metadata: None,
204 nickname: None,
205 product: None,
206 product_data: Some(stripe::CreatePriceProductData {
207 id: None,
208 active: Some(true),
209 metadata: None,
210 name: price_description.to_string(),
211 statement_descriptor: None,
212 tax_code: None,
213 unit_label: None,
214 }),
215 recurring: Some(stripe::CreatePriceRecurring {
216 aggregate_usage: None,
217 interval: stripe::CreatePriceRecurringInterval::Month,
218 interval_count: None,
219 trial_period_days: None,
220 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
221 meter: Some(meter.id.clone()),
222 }),
223 tax_behavior: None,
224 tiers: None,
225 tiers_mode: None,
226 transfer_lookup_key: None,
227 transform_quantity: None,
228 unit_amount: None,
229 unit_amount_decimal: Some(&format!(
230 "{:.12}",
231 price_per_million_tokens.0 as f64 / 1_000_000f64
232 )),
233 },
234 )
235 .await?;
236 state
237 .price_ids_by_meter_id
238 .insert(meter.id, price.id.clone());
239 price.id
240 };
241
242 Ok(StripeBillingPrice {
243 id: price_id,
244 meter_event_name: meter_event_name.to_string(),
245 })
246 }
247
248 pub async fn subscribe_to_price(
249 &self,
250 subscription_id: &stripe::SubscriptionId,
251 price: &stripe::Price,
252 ) -> Result<()> {
253 let subscription =
254 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
255
256 if subscription_contains_price(&subscription, &price.id) {
257 return Ok(());
258 }
259
260 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
261
262 let price_per_unit = price.unit_amount.unwrap_or_default();
263 let units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
264
265 stripe::Subscription::update(
266 &self.client,
267 subscription_id,
268 stripe::UpdateSubscription {
269 items: Some(vec![stripe::UpdateSubscriptionItems {
270 price: Some(price.id.to_string()),
271 billing_thresholds: Some(stripe::SubscriptionItemBillingThresholds {
272 usage_gte: Some(units_for_billing_threshold),
273 }),
274 ..Default::default()
275 }]),
276 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
277 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
278 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
279 },
280 }),
281 ..Default::default()
282 },
283 )
284 .await?;
285
286 Ok(())
287 }
288
289 pub async fn subscribe_to_model(
290 &self,
291 subscription_id: &stripe::SubscriptionId,
292 model: &StripeModelTokenPrices,
293 ) -> Result<()> {
294 let subscription =
295 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
296
297 let mut items = Vec::new();
298
299 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
300 items.push(stripe::UpdateSubscriptionItems {
301 price: Some(model.input_tokens_price.id.to_string()),
302 ..Default::default()
303 });
304 }
305
306 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
307 {
308 items.push(stripe::UpdateSubscriptionItems {
309 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
310 ..Default::default()
311 });
312 }
313
314 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
315 items.push(stripe::UpdateSubscriptionItems {
316 price: Some(model.input_cache_read_tokens_price.id.to_string()),
317 ..Default::default()
318 });
319 }
320
321 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
322 items.push(stripe::UpdateSubscriptionItems {
323 price: Some(model.output_tokens_price.id.to_string()),
324 ..Default::default()
325 });
326 }
327
328 if !items.is_empty() {
329 items.extend(subscription.items.data.iter().map(|item| {
330 stripe::UpdateSubscriptionItems {
331 id: Some(item.id.to_string()),
332 ..Default::default()
333 }
334 }));
335
336 stripe::Subscription::update(
337 &self.client,
338 subscription_id,
339 stripe::UpdateSubscription {
340 items: Some(items),
341 ..Default::default()
342 },
343 )
344 .await?;
345 }
346
347 Ok(())
348 }
349
350 pub async fn bill_model_token_usage(
351 &self,
352 customer_id: &stripe::CustomerId,
353 model: &StripeModelTokenPrices,
354 event: &llm::db::billing_event::Model,
355 ) -> Result<()> {
356 let timestamp = Utc::now().timestamp();
357
358 if event.input_tokens > 0 {
359 StripeMeterEvent::create(
360 &self.client,
361 StripeCreateMeterEventParams {
362 identifier: &format!("input_tokens/{}", event.idempotency_key),
363 event_name: &model.input_tokens_price.meter_event_name,
364 payload: StripeCreateMeterEventPayload {
365 value: event.input_tokens as u64,
366 stripe_customer_id: customer_id,
367 },
368 timestamp: Some(timestamp),
369 },
370 )
371 .await?;
372 }
373
374 if event.input_cache_creation_tokens > 0 {
375 StripeMeterEvent::create(
376 &self.client,
377 StripeCreateMeterEventParams {
378 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
379 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
380 payload: StripeCreateMeterEventPayload {
381 value: event.input_cache_creation_tokens as u64,
382 stripe_customer_id: customer_id,
383 },
384 timestamp: Some(timestamp),
385 },
386 )
387 .await?;
388 }
389
390 if event.input_cache_read_tokens > 0 {
391 StripeMeterEvent::create(
392 &self.client,
393 StripeCreateMeterEventParams {
394 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
395 event_name: &model.input_cache_read_tokens_price.meter_event_name,
396 payload: StripeCreateMeterEventPayload {
397 value: event.input_cache_read_tokens as u64,
398 stripe_customer_id: customer_id,
399 },
400 timestamp: Some(timestamp),
401 },
402 )
403 .await?;
404 }
405
406 if event.output_tokens > 0 {
407 StripeMeterEvent::create(
408 &self.client,
409 StripeCreateMeterEventParams {
410 identifier: &format!("output_tokens/{}", event.idempotency_key),
411 event_name: &model.output_tokens_price.meter_event_name,
412 payload: StripeCreateMeterEventPayload {
413 value: event.output_tokens as u64,
414 stripe_customer_id: customer_id,
415 },
416 timestamp: Some(timestamp),
417 },
418 )
419 .await?;
420 }
421
422 Ok(())
423 }
424
425 pub async fn bill_model_request_usage(
426 &self,
427 customer_id: &stripe::CustomerId,
428 event_name: &str,
429 requests: i32,
430 ) -> Result<()> {
431 let timestamp = Utc::now().timestamp();
432 let idempotency_key = Uuid::new_v4();
433
434 StripeMeterEvent::create(
435 &self.client,
436 StripeCreateMeterEventParams {
437 identifier: &format!("model_requests/{}", idempotency_key),
438 event_name,
439 payload: StripeCreateMeterEventPayload {
440 value: requests as u64,
441 stripe_customer_id: customer_id,
442 },
443 timestamp: Some(timestamp),
444 },
445 )
446 .await?;
447
448 Ok(())
449 }
450
451 pub async fn checkout(
452 &self,
453 customer_id: stripe::CustomerId,
454 github_login: &str,
455 model: &StripeModelTokenPrices,
456 success_url: &str,
457 ) -> Result<String> {
458 let first_of_next_month = Utc::now()
459 .checked_add_months(chrono::Months::new(1))
460 .unwrap()
461 .with_day(1)
462 .unwrap();
463
464 let mut params = stripe::CreateCheckoutSession::new();
465 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
466 params.customer = Some(customer_id);
467 params.client_reference_id = Some(github_login);
468 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
469 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
470 ..Default::default()
471 });
472 params.line_items = Some(
473 [
474 &model.input_tokens_price.id,
475 &model.input_cache_creation_tokens_price.id,
476 &model.input_cache_read_tokens_price.id,
477 &model.output_tokens_price.id,
478 ]
479 .into_iter()
480 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
481 price: Some(price_id.to_string()),
482 ..Default::default()
483 })
484 .collect(),
485 );
486 params.success_url = Some(success_url);
487
488 let session = stripe::CheckoutSession::create(&self.client, params).await?;
489 Ok(session.url.context("no checkout session URL")?)
490 }
491
492 pub async fn checkout_with_zed_pro(
493 &self,
494 customer_id: stripe::CustomerId,
495 github_login: &str,
496 success_url: &str,
497 ) -> Result<String> {
498 let zed_pro_price_id = self.zed_pro_price_id().await?;
499
500 let mut params = stripe::CreateCheckoutSession::new();
501 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
502 params.customer = Some(customer_id);
503 params.client_reference_id = Some(github_login);
504 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
505 price: Some(zed_pro_price_id.to_string()),
506 quantity: Some(1),
507 ..Default::default()
508 }]);
509 params.success_url = Some(success_url);
510
511 let session = stripe::CheckoutSession::create(&self.client, params).await?;
512 Ok(session.url.context("no checkout session URL")?)
513 }
514
515 pub async fn checkout_with_zed_pro_trial(
516 &self,
517 customer_id: stripe::CustomerId,
518 github_login: &str,
519 feature_flags: Vec<String>,
520 success_url: &str,
521 ) -> Result<String> {
522 let zed_pro_price_id = self.zed_pro_price_id().await?;
523
524 let eligible_for_extended_trial = feature_flags
525 .iter()
526 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
527
528 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
529
530 let mut subscription_metadata = std::collections::HashMap::new();
531 if eligible_for_extended_trial {
532 subscription_metadata.insert(
533 "promo_feature_flag".to_string(),
534 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
535 );
536 }
537
538 let mut params = stripe::CreateCheckoutSession::new();
539 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
540 trial_period_days: Some(trial_period_days),
541 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
542 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
543 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
544 }
545 }),
546 metadata: if !subscription_metadata.is_empty() {
547 Some(subscription_metadata)
548 } else {
549 None
550 },
551 ..Default::default()
552 });
553 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
554 params.payment_method_collection =
555 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
556 params.customer = Some(customer_id);
557 params.client_reference_id = Some(github_login);
558 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
559 price: Some(zed_pro_price_id.to_string()),
560 quantity: Some(1),
561 ..Default::default()
562 }]);
563 params.success_url = Some(success_url);
564
565 let session = stripe::CheckoutSession::create(&self.client, params).await?;
566 Ok(session.url.context("no checkout session URL")?)
567 }
568
569 pub async fn checkout_with_zed_free(
570 &self,
571 customer_id: stripe::CustomerId,
572 github_login: &str,
573 success_url: &str,
574 ) -> Result<String> {
575 let zed_free_price_id = self.zed_free_price_id().await?;
576
577 let mut params = stripe::CreateCheckoutSession::new();
578 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
579 params.customer = Some(customer_id);
580 params.client_reference_id = Some(github_login);
581 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
582 price: Some(zed_free_price_id.to_string()),
583 quantity: Some(1),
584 ..Default::default()
585 }]);
586 params.success_url = Some(success_url);
587
588 let session = stripe::CheckoutSession::create(&self.client, params).await?;
589 Ok(session.url.context("no checkout session URL")?)
590 }
591}
592
/// Aggregation settings sent when creating a Stripe billing meter.
#[derive(Serialize)]
struct DefaultAggregation {
    /// Name of the aggregation formula (this module passes `"sum"`).
    formula: &'static str,
}
597
/// Form parameters for creating a billing meter via `POST /billing/meters`.
///
/// Fields are serialized in declaration order.
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
    /// How the meter aggregates its events.
    default_aggregation: DefaultAggregation,
    /// Display name of the meter.
    display_name: String,
    /// Event name that meter events must use to be counted by this meter.
    event_name: &'a str,
}
604
605#[derive(Clone, Deserialize)]
606struct StripeMeter {
607 id: String,
608 event_name: String,
609}
610
611impl StripeMeter {
612 pub fn create(
613 client: &stripe::Client,
614 params: StripeCreateMeterParams,
615 ) -> stripe::Response<Self> {
616 client.post_form("/billing/meters", params)
617 }
618
619 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
620 #[derive(Serialize)]
621 struct Params {
622 #[serde(skip_serializing_if = "Option::is_none")]
623 limit: Option<u64>,
624 }
625
626 client.get_query("/billing/meters", Params { limit: Some(100) })
627 }
628}
629
/// A Stripe billing meter event; only its identifier is kept.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier the meter event was submitted with.
    identifier: String,
}
634
635impl StripeMeterEvent {
636 pub async fn create(
637 client: &stripe::Client,
638 params: StripeCreateMeterEventParams<'_>,
639 ) -> Result<Self, stripe::StripeError> {
640 let identifier = params.identifier;
641 match client.post_form("/billing/meter_events", params).await {
642 Ok(event) => Ok(event),
643 Err(stripe::StripeError::Stripe(error)) => {
644 if error.http_status == 400
645 && error
646 .message
647 .as_ref()
648 .map_or(false, |message| message.contains(identifier))
649 {
650 Ok(Self {
651 identifier: identifier.to_string(),
652 })
653 } else {
654 Err(stripe::StripeError::Stripe(error))
655 }
656 }
657 Err(error) => Err(error),
658 }
659 }
660}
661
/// Form parameters for recording a meter event via `POST /billing/meter_events`.
///
/// Fields are serialized in declaration order.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier of this event. Callers in this module build it from an
    /// idempotency key or a random UUID.
    identifier: &'a str,
    /// Event name of the meter the usage is reported against.
    event_name: &'a str,
    /// Usage value and the customer it applies to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time as a Unix timestamp in seconds. Callers in this module set
    /// it from `Utc::now().timestamp()`.
    timestamp: Option<i64>,
}
669
/// Payload of a meter event: the usage quantity and the billed customer.
///
/// Fields are serialized in declaration order.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Quantity of usage to record (e.g. a token or request count).
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
675
676fn subscription_contains_price(
677 subscription: &stripe::Subscription,
678 price_id: &stripe::PriceId,
679) -> bool {
680 subscription.items.data.iter().any(|item| {
681 item.price
682 .as_ref()
683 .map_or(false, |price| price.id == *price_id)
684 })
685}