1use std::sync::Arc;
2
3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
4use crate::{Cents, Result};
5use anyhow::{Context as _, anyhow};
6use chrono::{Datelike, Utc};
7use collections::HashMap;
8use serde::{Deserialize, Serialize};
9use stripe::PriceId;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13pub struct StripeBilling {
14 state: RwLock<StripeBillingState>,
15 client: Arc<stripe::Client>,
16}
17
18#[derive(Default)]
19struct StripeBillingState {
20 meters_by_event_name: HashMap<String, StripeMeter>,
21 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
22 prices_by_lookup_key: HashMap<String, stripe::Price>,
23}
24
25pub struct StripeModelTokenPrices {
26 input_tokens_price: StripeBillingPrice,
27 input_cache_creation_tokens_price: StripeBillingPrice,
28 input_cache_read_tokens_price: StripeBillingPrice,
29 output_tokens_price: StripeBillingPrice,
30}
31
32struct StripeBillingPrice {
33 id: stripe::PriceId,
34 meter_event_name: String,
35}
36
37impl StripeBilling {
38 pub fn new(client: Arc<stripe::Client>) -> Self {
39 Self {
40 client,
41 state: RwLock::default(),
42 }
43 }
44
45 pub async fn initialize(&self) -> Result<()> {
46 log::info!("StripeBilling: initializing");
47
48 let mut state = self.state.write().await;
49
50 let (meters, prices) = futures::try_join!(
51 StripeMeter::list(&self.client),
52 stripe::Price::list(
53 &self.client,
54 &stripe::ListPrices {
55 limit: Some(100),
56 ..Default::default()
57 }
58 )
59 )?;
60
61 for meter in meters.data {
62 state
63 .meters_by_event_name
64 .insert(meter.event_name.clone(), meter);
65 }
66
67 for price in prices.data {
68 if let Some(lookup_key) = price.lookup_key.clone() {
69 state.prices_by_lookup_key.insert(lookup_key, price.clone());
70 }
71
72 if let Some(recurring) = price.recurring {
73 if let Some(meter) = recurring.meter {
74 state.price_ids_by_meter_id.insert(meter, price.id);
75 }
76 }
77 }
78
79 log::info!("StripeBilling: initialized");
80
81 Ok(())
82 }
83
84 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
85 self.find_price_id_by_lookup_key("zed-pro").await
86 }
87
88 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
89 self.find_price_id_by_lookup_key("zed-free").await
90 }
91
92 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
93 self.state
94 .read()
95 .await
96 .prices_by_lookup_key
97 .get(lookup_key)
98 .map(|price| price.id.clone())
99 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
100 }
101
102 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
103 self.state
104 .read()
105 .await
106 .prices_by_lookup_key
107 .get(lookup_key)
108 .cloned()
109 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
110 }
111
112 pub async fn register_model_for_token_based_usage(
113 &self,
114 model: &llm::db::model::Model,
115 ) -> Result<StripeModelTokenPrices> {
116 let input_tokens_price = self
117 .get_or_insert_token_price(
118 &format!("model_{}/input_tokens", model.id),
119 &format!("{} (Input Tokens)", model.name),
120 Cents::new(model.price_per_million_input_tokens as u32),
121 )
122 .await?;
123 let input_cache_creation_tokens_price = self
124 .get_or_insert_token_price(
125 &format!("model_{}/input_cache_creation_tokens", model.id),
126 &format!("{} (Input Cache Creation Tokens)", model.name),
127 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
128 )
129 .await?;
130 let input_cache_read_tokens_price = self
131 .get_or_insert_token_price(
132 &format!("model_{}/input_cache_read_tokens", model.id),
133 &format!("{} (Input Cache Read Tokens)", model.name),
134 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
135 )
136 .await?;
137 let output_tokens_price = self
138 .get_or_insert_token_price(
139 &format!("model_{}/output_tokens", model.id),
140 &format!("{} (Output Tokens)", model.name),
141 Cents::new(model.price_per_million_output_tokens as u32),
142 )
143 .await?;
144 Ok(StripeModelTokenPrices {
145 input_tokens_price,
146 input_cache_creation_tokens_price,
147 input_cache_read_tokens_price,
148 output_tokens_price,
149 })
150 }
151
152 async fn get_or_insert_token_price(
153 &self,
154 meter_event_name: &str,
155 price_description: &str,
156 price_per_million_tokens: Cents,
157 ) -> Result<StripeBillingPrice> {
158 // Fast code path when the meter and the price already exist.
159 {
160 let state = self.state.read().await;
161 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
162 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
163 return Ok(StripeBillingPrice {
164 id: price_id.clone(),
165 meter_event_name: meter_event_name.to_string(),
166 });
167 }
168 }
169 }
170
171 let mut state = self.state.write().await;
172 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
173 meter.clone()
174 } else {
175 let meter = StripeMeter::create(
176 &self.client,
177 StripeCreateMeterParams {
178 default_aggregation: DefaultAggregation { formula: "sum" },
179 display_name: price_description.to_string(),
180 event_name: meter_event_name,
181 },
182 )
183 .await?;
184 state
185 .meters_by_event_name
186 .insert(meter_event_name.to_string(), meter.clone());
187 meter
188 };
189
190 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
191 price_id.clone()
192 } else {
193 let price = stripe::Price::create(
194 &self.client,
195 stripe::CreatePrice {
196 active: Some(true),
197 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
198 currency: stripe::Currency::USD,
199 currency_options: None,
200 custom_unit_amount: None,
201 expand: &[],
202 lookup_key: None,
203 metadata: None,
204 nickname: None,
205 product: None,
206 product_data: Some(stripe::CreatePriceProductData {
207 id: None,
208 active: Some(true),
209 metadata: None,
210 name: price_description.to_string(),
211 statement_descriptor: None,
212 tax_code: None,
213 unit_label: None,
214 }),
215 recurring: Some(stripe::CreatePriceRecurring {
216 aggregate_usage: None,
217 interval: stripe::CreatePriceRecurringInterval::Month,
218 interval_count: None,
219 trial_period_days: None,
220 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
221 meter: Some(meter.id.clone()),
222 }),
223 tax_behavior: None,
224 tiers: None,
225 tiers_mode: None,
226 transfer_lookup_key: None,
227 transform_quantity: None,
228 unit_amount: None,
229 unit_amount_decimal: Some(&format!(
230 "{:.12}",
231 price_per_million_tokens.0 as f64 / 1_000_000f64
232 )),
233 },
234 )
235 .await?;
236 state
237 .price_ids_by_meter_id
238 .insert(meter.id, price.id.clone());
239 price.id
240 };
241
242 Ok(StripeBillingPrice {
243 id: price_id,
244 meter_event_name: meter_event_name.to_string(),
245 })
246 }
247
248 pub async fn subscribe_to_price(
249 &self,
250 subscription_id: &stripe::SubscriptionId,
251 price: &stripe::Price,
252 ) -> Result<()> {
253 let subscription =
254 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
255
256 if subscription_contains_price(&subscription, &price.id) {
257 return Ok(());
258 }
259
260 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
261
262 let price_per_unit = price.unit_amount.unwrap_or_default();
263 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
264
265 stripe::Subscription::update(
266 &self.client,
267 subscription_id,
268 stripe::UpdateSubscription {
269 items: Some(vec![stripe::UpdateSubscriptionItems {
270 price: Some(price.id.to_string()),
271 ..Default::default()
272 }]),
273 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
274 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
275 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
276 },
277 }),
278 ..Default::default()
279 },
280 )
281 .await?;
282
283 Ok(())
284 }
285
286 pub async fn subscribe_to_model(
287 &self,
288 subscription_id: &stripe::SubscriptionId,
289 model: &StripeModelTokenPrices,
290 ) -> Result<()> {
291 let subscription =
292 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
293
294 let mut items = Vec::new();
295
296 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
297 items.push(stripe::UpdateSubscriptionItems {
298 price: Some(model.input_tokens_price.id.to_string()),
299 ..Default::default()
300 });
301 }
302
303 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
304 {
305 items.push(stripe::UpdateSubscriptionItems {
306 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
307 ..Default::default()
308 });
309 }
310
311 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
312 items.push(stripe::UpdateSubscriptionItems {
313 price: Some(model.input_cache_read_tokens_price.id.to_string()),
314 ..Default::default()
315 });
316 }
317
318 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
319 items.push(stripe::UpdateSubscriptionItems {
320 price: Some(model.output_tokens_price.id.to_string()),
321 ..Default::default()
322 });
323 }
324
325 if !items.is_empty() {
326 items.extend(subscription.items.data.iter().map(|item| {
327 stripe::UpdateSubscriptionItems {
328 id: Some(item.id.to_string()),
329 ..Default::default()
330 }
331 }));
332
333 stripe::Subscription::update(
334 &self.client,
335 subscription_id,
336 stripe::UpdateSubscription {
337 items: Some(items),
338 ..Default::default()
339 },
340 )
341 .await?;
342 }
343
344 Ok(())
345 }
346
347 pub async fn bill_model_token_usage(
348 &self,
349 customer_id: &stripe::CustomerId,
350 model: &StripeModelTokenPrices,
351 event: &llm::db::billing_event::Model,
352 ) -> Result<()> {
353 let timestamp = Utc::now().timestamp();
354
355 if event.input_tokens > 0 {
356 StripeMeterEvent::create(
357 &self.client,
358 StripeCreateMeterEventParams {
359 identifier: &format!("input_tokens/{}", event.idempotency_key),
360 event_name: &model.input_tokens_price.meter_event_name,
361 payload: StripeCreateMeterEventPayload {
362 value: event.input_tokens as u64,
363 stripe_customer_id: customer_id,
364 },
365 timestamp: Some(timestamp),
366 },
367 )
368 .await?;
369 }
370
371 if event.input_cache_creation_tokens > 0 {
372 StripeMeterEvent::create(
373 &self.client,
374 StripeCreateMeterEventParams {
375 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
376 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
377 payload: StripeCreateMeterEventPayload {
378 value: event.input_cache_creation_tokens as u64,
379 stripe_customer_id: customer_id,
380 },
381 timestamp: Some(timestamp),
382 },
383 )
384 .await?;
385 }
386
387 if event.input_cache_read_tokens > 0 {
388 StripeMeterEvent::create(
389 &self.client,
390 StripeCreateMeterEventParams {
391 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
392 event_name: &model.input_cache_read_tokens_price.meter_event_name,
393 payload: StripeCreateMeterEventPayload {
394 value: event.input_cache_read_tokens as u64,
395 stripe_customer_id: customer_id,
396 },
397 timestamp: Some(timestamp),
398 },
399 )
400 .await?;
401 }
402
403 if event.output_tokens > 0 {
404 StripeMeterEvent::create(
405 &self.client,
406 StripeCreateMeterEventParams {
407 identifier: &format!("output_tokens/{}", event.idempotency_key),
408 event_name: &model.output_tokens_price.meter_event_name,
409 payload: StripeCreateMeterEventPayload {
410 value: event.output_tokens as u64,
411 stripe_customer_id: customer_id,
412 },
413 timestamp: Some(timestamp),
414 },
415 )
416 .await?;
417 }
418
419 Ok(())
420 }
421
422 pub async fn bill_model_request_usage(
423 &self,
424 customer_id: &stripe::CustomerId,
425 event_name: &str,
426 requests: i32,
427 ) -> Result<()> {
428 let timestamp = Utc::now().timestamp();
429 let idempotency_key = Uuid::new_v4();
430
431 StripeMeterEvent::create(
432 &self.client,
433 StripeCreateMeterEventParams {
434 identifier: &format!("model_requests/{}", idempotency_key),
435 event_name,
436 payload: StripeCreateMeterEventPayload {
437 value: requests as u64,
438 stripe_customer_id: customer_id,
439 },
440 timestamp: Some(timestamp),
441 },
442 )
443 .await?;
444
445 Ok(())
446 }
447
448 pub async fn checkout(
449 &self,
450 customer_id: stripe::CustomerId,
451 github_login: &str,
452 model: &StripeModelTokenPrices,
453 success_url: &str,
454 ) -> Result<String> {
455 let first_of_next_month = Utc::now()
456 .checked_add_months(chrono::Months::new(1))
457 .unwrap()
458 .with_day(1)
459 .unwrap();
460
461 let mut params = stripe::CreateCheckoutSession::new();
462 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
463 params.customer = Some(customer_id);
464 params.client_reference_id = Some(github_login);
465 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
466 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
467 ..Default::default()
468 });
469 params.line_items = Some(
470 [
471 &model.input_tokens_price.id,
472 &model.input_cache_creation_tokens_price.id,
473 &model.input_cache_read_tokens_price.id,
474 &model.output_tokens_price.id,
475 ]
476 .into_iter()
477 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
478 price: Some(price_id.to_string()),
479 ..Default::default()
480 })
481 .collect(),
482 );
483 params.success_url = Some(success_url);
484
485 let session = stripe::CheckoutSession::create(&self.client, params).await?;
486 Ok(session.url.context("no checkout session URL")?)
487 }
488
489 pub async fn checkout_with_zed_pro(
490 &self,
491 customer_id: stripe::CustomerId,
492 github_login: &str,
493 success_url: &str,
494 ) -> Result<String> {
495 let zed_pro_price_id = self.zed_pro_price_id().await?;
496
497 let mut params = stripe::CreateCheckoutSession::new();
498 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
499 params.customer = Some(customer_id);
500 params.client_reference_id = Some(github_login);
501 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
502 price: Some(zed_pro_price_id.to_string()),
503 quantity: Some(1),
504 ..Default::default()
505 }]);
506 params.success_url = Some(success_url);
507
508 let session = stripe::CheckoutSession::create(&self.client, params).await?;
509 Ok(session.url.context("no checkout session URL")?)
510 }
511
512 pub async fn checkout_with_zed_pro_trial(
513 &self,
514 customer_id: stripe::CustomerId,
515 github_login: &str,
516 feature_flags: Vec<String>,
517 success_url: &str,
518 ) -> Result<String> {
519 let zed_pro_price_id = self.zed_pro_price_id().await?;
520
521 let eligible_for_extended_trial = feature_flags
522 .iter()
523 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
524
525 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
526
527 let mut subscription_metadata = std::collections::HashMap::new();
528 if eligible_for_extended_trial {
529 subscription_metadata.insert(
530 "promo_feature_flag".to_string(),
531 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
532 );
533 }
534
535 let mut params = stripe::CreateCheckoutSession::new();
536 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
537 trial_period_days: Some(trial_period_days),
538 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
539 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
540 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
541 }
542 }),
543 metadata: if !subscription_metadata.is_empty() {
544 Some(subscription_metadata)
545 } else {
546 None
547 },
548 ..Default::default()
549 });
550 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
551 params.payment_method_collection =
552 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
553 params.customer = Some(customer_id);
554 params.client_reference_id = Some(github_login);
555 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
556 price: Some(zed_pro_price_id.to_string()),
557 quantity: Some(1),
558 ..Default::default()
559 }]);
560 params.success_url = Some(success_url);
561
562 let session = stripe::CheckoutSession::create(&self.client, params).await?;
563 Ok(session.url.context("no checkout session URL")?)
564 }
565
566 pub async fn checkout_with_zed_free(
567 &self,
568 customer_id: stripe::CustomerId,
569 github_login: &str,
570 success_url: &str,
571 ) -> Result<String> {
572 let zed_free_price_id = self.zed_free_price_id().await?;
573
574 let mut params = stripe::CreateCheckoutSession::new();
575 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
576 params.customer = Some(customer_id);
577 params.client_reference_id = Some(github_login);
578 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
579 price: Some(zed_free_price_id.to_string()),
580 quantity: Some(1),
581 ..Default::default()
582 }]);
583 params.success_url = Some(success_url);
584
585 let session = stripe::CheckoutSession::create(&self.client, params).await?;
586 Ok(session.url.context("no checkout session URL")?)
587 }
588}
589
590#[derive(Serialize)]
591struct DefaultAggregation {
592 formula: &'static str,
593}
594
595#[derive(Serialize)]
596struct StripeCreateMeterParams<'a> {
597 default_aggregation: DefaultAggregation,
598 display_name: String,
599 event_name: &'a str,
600}
601
602#[derive(Clone, Deserialize)]
603struct StripeMeter {
604 id: String,
605 event_name: String,
606}
607
608impl StripeMeter {
609 pub fn create(
610 client: &stripe::Client,
611 params: StripeCreateMeterParams,
612 ) -> stripe::Response<Self> {
613 client.post_form("/billing/meters", params)
614 }
615
616 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
617 #[derive(Serialize)]
618 struct Params {
619 #[serde(skip_serializing_if = "Option::is_none")]
620 limit: Option<u64>,
621 }
622
623 client.get_query("/billing/meters", Params { limit: Some(100) })
624 }
625}
626
627#[derive(Deserialize)]
628struct StripeMeterEvent {
629 identifier: String,
630}
631
632impl StripeMeterEvent {
633 pub async fn create(
634 client: &stripe::Client,
635 params: StripeCreateMeterEventParams<'_>,
636 ) -> Result<Self, stripe::StripeError> {
637 let identifier = params.identifier;
638 match client.post_form("/billing/meter_events", params).await {
639 Ok(event) => Ok(event),
640 Err(stripe::StripeError::Stripe(error)) => {
641 if error.http_status == 400
642 && error
643 .message
644 .as_ref()
645 .map_or(false, |message| message.contains(identifier))
646 {
647 Ok(Self {
648 identifier: identifier.to_string(),
649 })
650 } else {
651 Err(stripe::StripeError::Stripe(error))
652 }
653 }
654 Err(error) => Err(error),
655 }
656 }
657}
658
659#[derive(Serialize)]
660struct StripeCreateMeterEventParams<'a> {
661 identifier: &'a str,
662 event_name: &'a str,
663 payload: StripeCreateMeterEventPayload<'a>,
664 timestamp: Option<i64>,
665}
666
667#[derive(Serialize)]
668struct StripeCreateMeterEventPayload<'a> {
669 value: u64,
670 stripe_customer_id: &'a stripe::CustomerId,
671}
672
673fn subscription_contains_price(
674 subscription: &stripe::Subscription,
675 price_id: &stripe::PriceId,
676) -> bool {
677 subscription.items.data.iter().any(|item| {
678 item.price
679 .as_ref()
680 .map_or(false, |price| price.id == *price_id)
681 })
682}