1use std::sync::Arc;
2
3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
4use crate::{Cents, Result};
5use anyhow::{Context as _, anyhow};
6use chrono::{Datelike, Utc};
7use collections::HashMap;
8use serde::{Deserialize, Serialize};
9use stripe::PriceId;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13pub struct StripeBilling {
14 state: RwLock<StripeBillingState>,
15 client: Arc<stripe::Client>,
16}
17
18#[derive(Default)]
19struct StripeBillingState {
20 meters_by_event_name: HashMap<String, StripeMeter>,
21 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
22 prices_by_lookup_key: HashMap<String, stripe::Price>,
23}
24
25pub struct StripeModelTokenPrices {
26 input_tokens_price: StripeBillingPrice,
27 input_cache_creation_tokens_price: StripeBillingPrice,
28 input_cache_read_tokens_price: StripeBillingPrice,
29 output_tokens_price: StripeBillingPrice,
30}
31
32struct StripeBillingPrice {
33 id: stripe::PriceId,
34 meter_event_name: String,
35}
36
37impl StripeBilling {
38 pub fn new(client: Arc<stripe::Client>) -> Self {
39 Self {
40 client,
41 state: RwLock::default(),
42 }
43 }
44
45 pub async fn initialize(&self) -> Result<()> {
46 log::info!("StripeBilling: initializing");
47
48 let mut state = self.state.write().await;
49
50 let (meters, prices) = futures::try_join!(
51 StripeMeter::list(&self.client),
52 stripe::Price::list(
53 &self.client,
54 &stripe::ListPrices {
55 limit: Some(100),
56 ..Default::default()
57 }
58 )
59 )?;
60
61 for meter in meters.data {
62 state
63 .meters_by_event_name
64 .insert(meter.event_name.clone(), meter);
65 }
66
67 for price in prices.data {
68 if let Some(lookup_key) = price.lookup_key.clone() {
69 state.prices_by_lookup_key.insert(lookup_key, price.clone());
70 }
71
72 if let Some(recurring) = price.recurring {
73 if let Some(meter) = recurring.meter {
74 state.price_ids_by_meter_id.insert(meter, price.id);
75 }
76 }
77 }
78
79 log::info!("StripeBilling: initialized");
80
81 Ok(())
82 }
83
84 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
85 self.state
86 .read()
87 .await
88 .prices_by_lookup_key
89 .get(lookup_key)
90 .cloned()
91 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
92 }
93
94 pub async fn register_model_for_token_based_usage(
95 &self,
96 model: &llm::db::model::Model,
97 ) -> Result<StripeModelTokenPrices> {
98 let input_tokens_price = self
99 .get_or_insert_token_price(
100 &format!("model_{}/input_tokens", model.id),
101 &format!("{} (Input Tokens)", model.name),
102 Cents::new(model.price_per_million_input_tokens as u32),
103 )
104 .await?;
105 let input_cache_creation_tokens_price = self
106 .get_or_insert_token_price(
107 &format!("model_{}/input_cache_creation_tokens", model.id),
108 &format!("{} (Input Cache Creation Tokens)", model.name),
109 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
110 )
111 .await?;
112 let input_cache_read_tokens_price = self
113 .get_or_insert_token_price(
114 &format!("model_{}/input_cache_read_tokens", model.id),
115 &format!("{} (Input Cache Read Tokens)", model.name),
116 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
117 )
118 .await?;
119 let output_tokens_price = self
120 .get_or_insert_token_price(
121 &format!("model_{}/output_tokens", model.id),
122 &format!("{} (Output Tokens)", model.name),
123 Cents::new(model.price_per_million_output_tokens as u32),
124 )
125 .await?;
126 Ok(StripeModelTokenPrices {
127 input_tokens_price,
128 input_cache_creation_tokens_price,
129 input_cache_read_tokens_price,
130 output_tokens_price,
131 })
132 }
133
134 async fn get_or_insert_token_price(
135 &self,
136 meter_event_name: &str,
137 price_description: &str,
138 price_per_million_tokens: Cents,
139 ) -> Result<StripeBillingPrice> {
140 // Fast code path when the meter and the price already exist.
141 {
142 let state = self.state.read().await;
143 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
144 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
145 return Ok(StripeBillingPrice {
146 id: price_id.clone(),
147 meter_event_name: meter_event_name.to_string(),
148 });
149 }
150 }
151 }
152
153 let mut state = self.state.write().await;
154 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
155 meter.clone()
156 } else {
157 let meter = StripeMeter::create(
158 &self.client,
159 StripeCreateMeterParams {
160 default_aggregation: DefaultAggregation { formula: "sum" },
161 display_name: price_description.to_string(),
162 event_name: meter_event_name,
163 },
164 )
165 .await?;
166 state
167 .meters_by_event_name
168 .insert(meter_event_name.to_string(), meter.clone());
169 meter
170 };
171
172 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
173 price_id.clone()
174 } else {
175 let price = stripe::Price::create(
176 &self.client,
177 stripe::CreatePrice {
178 active: Some(true),
179 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
180 currency: stripe::Currency::USD,
181 currency_options: None,
182 custom_unit_amount: None,
183 expand: &[],
184 lookup_key: None,
185 metadata: None,
186 nickname: None,
187 product: None,
188 product_data: Some(stripe::CreatePriceProductData {
189 id: None,
190 active: Some(true),
191 metadata: None,
192 name: price_description.to_string(),
193 statement_descriptor: None,
194 tax_code: None,
195 unit_label: None,
196 }),
197 recurring: Some(stripe::CreatePriceRecurring {
198 aggregate_usage: None,
199 interval: stripe::CreatePriceRecurringInterval::Month,
200 interval_count: None,
201 trial_period_days: None,
202 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
203 meter: Some(meter.id.clone()),
204 }),
205 tax_behavior: None,
206 tiers: None,
207 tiers_mode: None,
208 transfer_lookup_key: None,
209 transform_quantity: None,
210 unit_amount: None,
211 unit_amount_decimal: Some(&format!(
212 "{:.12}",
213 price_per_million_tokens.0 as f64 / 1_000_000f64
214 )),
215 },
216 )
217 .await?;
218 state
219 .price_ids_by_meter_id
220 .insert(meter.id, price.id.clone());
221 price.id
222 };
223
224 Ok(StripeBillingPrice {
225 id: price_id,
226 meter_event_name: meter_event_name.to_string(),
227 })
228 }
229
230 pub async fn subscribe_to_price(
231 &self,
232 subscription_id: &stripe::SubscriptionId,
233 price_id: &stripe::PriceId,
234 ) -> Result<()> {
235 let subscription =
236 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
237
238 if subscription_contains_price(&subscription, price_id) {
239 return Ok(());
240 }
241
242 stripe::Subscription::update(
243 &self.client,
244 subscription_id,
245 stripe::UpdateSubscription {
246 items: Some(vec![stripe::UpdateSubscriptionItems {
247 price: Some(price_id.to_string()),
248 ..Default::default()
249 }]),
250 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
251 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
252 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
253 },
254 }),
255 ..Default::default()
256 },
257 )
258 .await?;
259
260 Ok(())
261 }
262
263 pub async fn subscribe_to_model(
264 &self,
265 subscription_id: &stripe::SubscriptionId,
266 model: &StripeModelTokenPrices,
267 ) -> Result<()> {
268 let subscription =
269 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
270
271 let mut items = Vec::new();
272
273 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
274 items.push(stripe::UpdateSubscriptionItems {
275 price: Some(model.input_tokens_price.id.to_string()),
276 ..Default::default()
277 });
278 }
279
280 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
281 {
282 items.push(stripe::UpdateSubscriptionItems {
283 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
284 ..Default::default()
285 });
286 }
287
288 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
289 items.push(stripe::UpdateSubscriptionItems {
290 price: Some(model.input_cache_read_tokens_price.id.to_string()),
291 ..Default::default()
292 });
293 }
294
295 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
296 items.push(stripe::UpdateSubscriptionItems {
297 price: Some(model.output_tokens_price.id.to_string()),
298 ..Default::default()
299 });
300 }
301
302 if !items.is_empty() {
303 items.extend(subscription.items.data.iter().map(|item| {
304 stripe::UpdateSubscriptionItems {
305 id: Some(item.id.to_string()),
306 ..Default::default()
307 }
308 }));
309
310 stripe::Subscription::update(
311 &self.client,
312 subscription_id,
313 stripe::UpdateSubscription {
314 items: Some(items),
315 ..Default::default()
316 },
317 )
318 .await?;
319 }
320
321 Ok(())
322 }
323
324 pub async fn bill_model_token_usage(
325 &self,
326 customer_id: &stripe::CustomerId,
327 model: &StripeModelTokenPrices,
328 event: &llm::db::billing_event::Model,
329 ) -> Result<()> {
330 let timestamp = Utc::now().timestamp();
331
332 if event.input_tokens > 0 {
333 StripeMeterEvent::create(
334 &self.client,
335 StripeCreateMeterEventParams {
336 identifier: &format!("input_tokens/{}", event.idempotency_key),
337 event_name: &model.input_tokens_price.meter_event_name,
338 payload: StripeCreateMeterEventPayload {
339 value: event.input_tokens as u64,
340 stripe_customer_id: customer_id,
341 },
342 timestamp: Some(timestamp),
343 },
344 )
345 .await?;
346 }
347
348 if event.input_cache_creation_tokens > 0 {
349 StripeMeterEvent::create(
350 &self.client,
351 StripeCreateMeterEventParams {
352 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
353 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
354 payload: StripeCreateMeterEventPayload {
355 value: event.input_cache_creation_tokens as u64,
356 stripe_customer_id: customer_id,
357 },
358 timestamp: Some(timestamp),
359 },
360 )
361 .await?;
362 }
363
364 if event.input_cache_read_tokens > 0 {
365 StripeMeterEvent::create(
366 &self.client,
367 StripeCreateMeterEventParams {
368 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
369 event_name: &model.input_cache_read_tokens_price.meter_event_name,
370 payload: StripeCreateMeterEventPayload {
371 value: event.input_cache_read_tokens as u64,
372 stripe_customer_id: customer_id,
373 },
374 timestamp: Some(timestamp),
375 },
376 )
377 .await?;
378 }
379
380 if event.output_tokens > 0 {
381 StripeMeterEvent::create(
382 &self.client,
383 StripeCreateMeterEventParams {
384 identifier: &format!("output_tokens/{}", event.idempotency_key),
385 event_name: &model.output_tokens_price.meter_event_name,
386 payload: StripeCreateMeterEventPayload {
387 value: event.output_tokens as u64,
388 stripe_customer_id: customer_id,
389 },
390 timestamp: Some(timestamp),
391 },
392 )
393 .await?;
394 }
395
396 Ok(())
397 }
398
399 pub async fn bill_model_request_usage(
400 &self,
401 customer_id: &stripe::CustomerId,
402 event_name: &str,
403 requests: i32,
404 ) -> Result<()> {
405 let timestamp = Utc::now().timestamp();
406 let idempotency_key = Uuid::new_v4();
407
408 StripeMeterEvent::create(
409 &self.client,
410 StripeCreateMeterEventParams {
411 identifier: &format!("model_requests/{}", idempotency_key),
412 event_name,
413 payload: StripeCreateMeterEventPayload {
414 value: requests as u64,
415 stripe_customer_id: customer_id,
416 },
417 timestamp: Some(timestamp),
418 },
419 )
420 .await?;
421
422 Ok(())
423 }
424
425 pub async fn checkout(
426 &self,
427 customer_id: stripe::CustomerId,
428 github_login: &str,
429 model: &StripeModelTokenPrices,
430 success_url: &str,
431 ) -> Result<String> {
432 let first_of_next_month = Utc::now()
433 .checked_add_months(chrono::Months::new(1))
434 .unwrap()
435 .with_day(1)
436 .unwrap();
437
438 let mut params = stripe::CreateCheckoutSession::new();
439 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
440 params.customer = Some(customer_id);
441 params.client_reference_id = Some(github_login);
442 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
443 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
444 ..Default::default()
445 });
446 params.line_items = Some(
447 [
448 &model.input_tokens_price.id,
449 &model.input_cache_creation_tokens_price.id,
450 &model.input_cache_read_tokens_price.id,
451 &model.output_tokens_price.id,
452 ]
453 .into_iter()
454 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
455 price: Some(price_id.to_string()),
456 ..Default::default()
457 })
458 .collect(),
459 );
460 params.success_url = Some(success_url);
461
462 let session = stripe::CheckoutSession::create(&self.client, params).await?;
463 Ok(session.url.context("no checkout session URL")?)
464 }
465
466 pub async fn checkout_with_price(
467 &self,
468 price_id: PriceId,
469 customer_id: stripe::CustomerId,
470 github_login: &str,
471 success_url: &str,
472 ) -> Result<String> {
473 let mut params = stripe::CreateCheckoutSession::new();
474 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
475 params.customer = Some(customer_id);
476 params.client_reference_id = Some(github_login);
477 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
478 price: Some(price_id.to_string()),
479 quantity: Some(1),
480 ..Default::default()
481 }]);
482 params.success_url = Some(success_url);
483
484 let session = stripe::CheckoutSession::create(&self.client, params).await?;
485 Ok(session.url.context("no checkout session URL")?)
486 }
487
488 pub async fn checkout_with_zed_pro_trial(
489 &self,
490 zed_pro_price_id: PriceId,
491 customer_id: stripe::CustomerId,
492 github_login: &str,
493 feature_flags: Vec<String>,
494 success_url: &str,
495 ) -> Result<String> {
496 let eligible_for_extended_trial = feature_flags
497 .iter()
498 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
499
500 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
501
502 let mut subscription_metadata = std::collections::HashMap::new();
503 if eligible_for_extended_trial {
504 subscription_metadata.insert(
505 "promo_feature_flag".to_string(),
506 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
507 );
508 }
509
510 let mut params = stripe::CreateCheckoutSession::new();
511 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
512 trial_period_days: Some(trial_period_days),
513 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
514 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
515 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
516 }
517 }),
518 metadata: if !subscription_metadata.is_empty() {
519 Some(subscription_metadata)
520 } else {
521 None
522 },
523 ..Default::default()
524 });
525 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
526 params.payment_method_collection =
527 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
528 params.customer = Some(customer_id);
529 params.client_reference_id = Some(github_login);
530 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
531 price: Some(zed_pro_price_id.to_string()),
532 quantity: Some(1),
533 ..Default::default()
534 }]);
535 params.success_url = Some(success_url);
536
537 let session = stripe::CheckoutSession::create(&self.client, params).await?;
538 Ok(session.url.context("no checkout session URL")?)
539 }
540}
541
542#[derive(Serialize)]
543struct DefaultAggregation {
544 formula: &'static str,
545}
546
547#[derive(Serialize)]
548struct StripeCreateMeterParams<'a> {
549 default_aggregation: DefaultAggregation,
550 display_name: String,
551 event_name: &'a str,
552}
553
554#[derive(Clone, Deserialize)]
555struct StripeMeter {
556 id: String,
557 event_name: String,
558}
559
560impl StripeMeter {
561 pub fn create(
562 client: &stripe::Client,
563 params: StripeCreateMeterParams,
564 ) -> stripe::Response<Self> {
565 client.post_form("/billing/meters", params)
566 }
567
568 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
569 #[derive(Serialize)]
570 struct Params {
571 #[serde(skip_serializing_if = "Option::is_none")]
572 limit: Option<u64>,
573 }
574
575 client.get_query("/billing/meters", Params { limit: Some(100) })
576 }
577}
578
579#[derive(Deserialize)]
580struct StripeMeterEvent {
581 identifier: String,
582}
583
584impl StripeMeterEvent {
585 pub async fn create(
586 client: &stripe::Client,
587 params: StripeCreateMeterEventParams<'_>,
588 ) -> Result<Self, stripe::StripeError> {
589 let identifier = params.identifier;
590 match client.post_form("/billing/meter_events", params).await {
591 Ok(event) => Ok(event),
592 Err(stripe::StripeError::Stripe(error)) => {
593 if error.http_status == 400
594 && error
595 .message
596 .as_ref()
597 .map_or(false, |message| message.contains(identifier))
598 {
599 Ok(Self {
600 identifier: identifier.to_string(),
601 })
602 } else {
603 Err(stripe::StripeError::Stripe(error))
604 }
605 }
606 Err(error) => Err(error),
607 }
608 }
609}
610
611#[derive(Serialize)]
612struct StripeCreateMeterEventParams<'a> {
613 identifier: &'a str,
614 event_name: &'a str,
615 payload: StripeCreateMeterEventPayload<'a>,
616 timestamp: Option<i64>,
617}
618
619#[derive(Serialize)]
620struct StripeCreateMeterEventPayload<'a> {
621 value: u64,
622 stripe_customer_id: &'a stripe::CustomerId,
623}
624
625fn subscription_contains_price(
626 subscription: &stripe::Subscription,
627 price_id: &stripe::PriceId,
628) -> bool {
629 subscription.items.data.iter().any(|item| {
630 item.price
631 .as_ref()
632 .map_or(false, |price| price.id == *price_id)
633 })
634}