1use std::sync::Arc;
2
3use crate::{Cents, Result, llm};
4use anyhow::{Context as _, anyhow};
5use chrono::{Datelike, Utc};
6use collections::HashMap;
7use serde::{Deserialize, Serialize};
8use stripe::PriceId;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12pub struct StripeBilling {
13 state: RwLock<StripeBillingState>,
14 client: Arc<stripe::Client>,
15}
16
17#[derive(Default)]
18struct StripeBillingState {
19 meters_by_event_name: HashMap<String, StripeMeter>,
20 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
21 prices_by_lookup_key: HashMap<String, stripe::Price>,
22}
23
24pub struct StripeModelTokenPrices {
25 input_tokens_price: StripeBillingPrice,
26 input_cache_creation_tokens_price: StripeBillingPrice,
27 input_cache_read_tokens_price: StripeBillingPrice,
28 output_tokens_price: StripeBillingPrice,
29}
30
31struct StripeBillingPrice {
32 id: stripe::PriceId,
33 meter_event_name: String,
34}
35
36impl StripeBilling {
37 pub fn new(client: Arc<stripe::Client>) -> Self {
38 Self {
39 client,
40 state: RwLock::default(),
41 }
42 }
43
44 pub async fn initialize(&self) -> Result<()> {
45 log::info!("StripeBilling: initializing");
46
47 let mut state = self.state.write().await;
48
49 let (meters, prices) = futures::try_join!(
50 StripeMeter::list(&self.client),
51 stripe::Price::list(
52 &self.client,
53 &stripe::ListPrices {
54 limit: Some(100),
55 ..Default::default()
56 }
57 )
58 )?;
59
60 for meter in meters.data {
61 state
62 .meters_by_event_name
63 .insert(meter.event_name.clone(), meter);
64 }
65
66 for price in prices.data {
67 if let Some(lookup_key) = price.lookup_key.clone() {
68 state.prices_by_lookup_key.insert(lookup_key, price.clone());
69 }
70
71 if let Some(recurring) = price.recurring {
72 if let Some(meter) = recurring.meter {
73 state.price_ids_by_meter_id.insert(meter, price.id);
74 }
75 }
76 }
77
78 log::info!("StripeBilling: initialized");
79
80 Ok(())
81 }
82
83 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
84 self.state
85 .read()
86 .await
87 .prices_by_lookup_key
88 .get(lookup_key)
89 .cloned()
90 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
91 }
92
93 pub async fn register_model_for_token_based_usage(
94 &self,
95 model: &llm::db::model::Model,
96 ) -> Result<StripeModelTokenPrices> {
97 let input_tokens_price = self
98 .get_or_insert_token_price(
99 &format!("model_{}/input_tokens", model.id),
100 &format!("{} (Input Tokens)", model.name),
101 Cents::new(model.price_per_million_input_tokens as u32),
102 )
103 .await?;
104 let input_cache_creation_tokens_price = self
105 .get_or_insert_token_price(
106 &format!("model_{}/input_cache_creation_tokens", model.id),
107 &format!("{} (Input Cache Creation Tokens)", model.name),
108 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
109 )
110 .await?;
111 let input_cache_read_tokens_price = self
112 .get_or_insert_token_price(
113 &format!("model_{}/input_cache_read_tokens", model.id),
114 &format!("{} (Input Cache Read Tokens)", model.name),
115 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
116 )
117 .await?;
118 let output_tokens_price = self
119 .get_or_insert_token_price(
120 &format!("model_{}/output_tokens", model.id),
121 &format!("{} (Output Tokens)", model.name),
122 Cents::new(model.price_per_million_output_tokens as u32),
123 )
124 .await?;
125 Ok(StripeModelTokenPrices {
126 input_tokens_price,
127 input_cache_creation_tokens_price,
128 input_cache_read_tokens_price,
129 output_tokens_price,
130 })
131 }
132
133 async fn get_or_insert_token_price(
134 &self,
135 meter_event_name: &str,
136 price_description: &str,
137 price_per_million_tokens: Cents,
138 ) -> Result<StripeBillingPrice> {
139 // Fast code path when the meter and the price already exist.
140 {
141 let state = self.state.read().await;
142 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
143 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
144 return Ok(StripeBillingPrice {
145 id: price_id.clone(),
146 meter_event_name: meter_event_name.to_string(),
147 });
148 }
149 }
150 }
151
152 let mut state = self.state.write().await;
153 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
154 meter.clone()
155 } else {
156 let meter = StripeMeter::create(
157 &self.client,
158 StripeCreateMeterParams {
159 default_aggregation: DefaultAggregation { formula: "sum" },
160 display_name: price_description.to_string(),
161 event_name: meter_event_name,
162 },
163 )
164 .await?;
165 state
166 .meters_by_event_name
167 .insert(meter_event_name.to_string(), meter.clone());
168 meter
169 };
170
171 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
172 price_id.clone()
173 } else {
174 let price = stripe::Price::create(
175 &self.client,
176 stripe::CreatePrice {
177 active: Some(true),
178 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
179 currency: stripe::Currency::USD,
180 currency_options: None,
181 custom_unit_amount: None,
182 expand: &[],
183 lookup_key: None,
184 metadata: None,
185 nickname: None,
186 product: None,
187 product_data: Some(stripe::CreatePriceProductData {
188 id: None,
189 active: Some(true),
190 metadata: None,
191 name: price_description.to_string(),
192 statement_descriptor: None,
193 tax_code: None,
194 unit_label: None,
195 }),
196 recurring: Some(stripe::CreatePriceRecurring {
197 aggregate_usage: None,
198 interval: stripe::CreatePriceRecurringInterval::Month,
199 interval_count: None,
200 trial_period_days: None,
201 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
202 meter: Some(meter.id.clone()),
203 }),
204 tax_behavior: None,
205 tiers: None,
206 tiers_mode: None,
207 transfer_lookup_key: None,
208 transform_quantity: None,
209 unit_amount: None,
210 unit_amount_decimal: Some(&format!(
211 "{:.12}",
212 price_per_million_tokens.0 as f64 / 1_000_000f64
213 )),
214 },
215 )
216 .await?;
217 state
218 .price_ids_by_meter_id
219 .insert(meter.id, price.id.clone());
220 price.id
221 };
222
223 Ok(StripeBillingPrice {
224 id: price_id,
225 meter_event_name: meter_event_name.to_string(),
226 })
227 }
228
229 pub async fn subscribe_to_price(
230 &self,
231 subscription_id: &stripe::SubscriptionId,
232 price_id: &stripe::PriceId,
233 ) -> Result<()> {
234 let subscription =
235 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
236
237 if subscription_contains_price(&subscription, price_id) {
238 return Ok(());
239 }
240
241 stripe::Subscription::update(
242 &self.client,
243 subscription_id,
244 stripe::UpdateSubscription {
245 items: Some(vec![stripe::UpdateSubscriptionItems {
246 price: Some(price_id.to_string()),
247 ..Default::default()
248 }]),
249 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
250 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
251 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
252 },
253 }),
254 ..Default::default()
255 },
256 )
257 .await?;
258
259 Ok(())
260 }
261
262 pub async fn subscribe_to_model(
263 &self,
264 subscription_id: &stripe::SubscriptionId,
265 model: &StripeModelTokenPrices,
266 ) -> Result<()> {
267 let subscription =
268 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
269
270 let mut items = Vec::new();
271
272 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
273 items.push(stripe::UpdateSubscriptionItems {
274 price: Some(model.input_tokens_price.id.to_string()),
275 ..Default::default()
276 });
277 }
278
279 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
280 {
281 items.push(stripe::UpdateSubscriptionItems {
282 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
283 ..Default::default()
284 });
285 }
286
287 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
288 items.push(stripe::UpdateSubscriptionItems {
289 price: Some(model.input_cache_read_tokens_price.id.to_string()),
290 ..Default::default()
291 });
292 }
293
294 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
295 items.push(stripe::UpdateSubscriptionItems {
296 price: Some(model.output_tokens_price.id.to_string()),
297 ..Default::default()
298 });
299 }
300
301 if !items.is_empty() {
302 items.extend(subscription.items.data.iter().map(|item| {
303 stripe::UpdateSubscriptionItems {
304 id: Some(item.id.to_string()),
305 ..Default::default()
306 }
307 }));
308
309 stripe::Subscription::update(
310 &self.client,
311 subscription_id,
312 stripe::UpdateSubscription {
313 items: Some(items),
314 ..Default::default()
315 },
316 )
317 .await?;
318 }
319
320 Ok(())
321 }
322
323 pub async fn bill_model_token_usage(
324 &self,
325 customer_id: &stripe::CustomerId,
326 model: &StripeModelTokenPrices,
327 event: &llm::db::billing_event::Model,
328 ) -> Result<()> {
329 let timestamp = Utc::now().timestamp();
330
331 if event.input_tokens > 0 {
332 StripeMeterEvent::create(
333 &self.client,
334 StripeCreateMeterEventParams {
335 identifier: &format!("input_tokens/{}", event.idempotency_key),
336 event_name: &model.input_tokens_price.meter_event_name,
337 payload: StripeCreateMeterEventPayload {
338 value: event.input_tokens as u64,
339 stripe_customer_id: customer_id,
340 },
341 timestamp: Some(timestamp),
342 },
343 )
344 .await?;
345 }
346
347 if event.input_cache_creation_tokens > 0 {
348 StripeMeterEvent::create(
349 &self.client,
350 StripeCreateMeterEventParams {
351 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
352 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
353 payload: StripeCreateMeterEventPayload {
354 value: event.input_cache_creation_tokens as u64,
355 stripe_customer_id: customer_id,
356 },
357 timestamp: Some(timestamp),
358 },
359 )
360 .await?;
361 }
362
363 if event.input_cache_read_tokens > 0 {
364 StripeMeterEvent::create(
365 &self.client,
366 StripeCreateMeterEventParams {
367 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
368 event_name: &model.input_cache_read_tokens_price.meter_event_name,
369 payload: StripeCreateMeterEventPayload {
370 value: event.input_cache_read_tokens as u64,
371 stripe_customer_id: customer_id,
372 },
373 timestamp: Some(timestamp),
374 },
375 )
376 .await?;
377 }
378
379 if event.output_tokens > 0 {
380 StripeMeterEvent::create(
381 &self.client,
382 StripeCreateMeterEventParams {
383 identifier: &format!("output_tokens/{}", event.idempotency_key),
384 event_name: &model.output_tokens_price.meter_event_name,
385 payload: StripeCreateMeterEventPayload {
386 value: event.output_tokens as u64,
387 stripe_customer_id: customer_id,
388 },
389 timestamp: Some(timestamp),
390 },
391 )
392 .await?;
393 }
394
395 Ok(())
396 }
397
398 pub async fn bill_model_request_usage(
399 &self,
400 customer_id: &stripe::CustomerId,
401 event_name: &str,
402 requests: i32,
403 ) -> Result<()> {
404 let timestamp = Utc::now().timestamp();
405 let idempotency_key = Uuid::new_v4();
406
407 StripeMeterEvent::create(
408 &self.client,
409 StripeCreateMeterEventParams {
410 identifier: &format!("model_requests/{}", idempotency_key),
411 event_name,
412 payload: StripeCreateMeterEventPayload {
413 value: requests as u64,
414 stripe_customer_id: customer_id,
415 },
416 timestamp: Some(timestamp),
417 },
418 )
419 .await?;
420
421 Ok(())
422 }
423
424 pub async fn checkout(
425 &self,
426 customer_id: stripe::CustomerId,
427 github_login: &str,
428 model: &StripeModelTokenPrices,
429 success_url: &str,
430 ) -> Result<String> {
431 let first_of_next_month = Utc::now()
432 .checked_add_months(chrono::Months::new(1))
433 .unwrap()
434 .with_day(1)
435 .unwrap();
436
437 let mut params = stripe::CreateCheckoutSession::new();
438 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
439 params.customer = Some(customer_id);
440 params.client_reference_id = Some(github_login);
441 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
442 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
443 ..Default::default()
444 });
445 params.line_items = Some(
446 [
447 &model.input_tokens_price.id,
448 &model.input_cache_creation_tokens_price.id,
449 &model.input_cache_read_tokens_price.id,
450 &model.output_tokens_price.id,
451 ]
452 .into_iter()
453 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
454 price: Some(price_id.to_string()),
455 ..Default::default()
456 })
457 .collect(),
458 );
459 params.success_url = Some(success_url);
460
461 let session = stripe::CheckoutSession::create(&self.client, params).await?;
462 Ok(session.url.context("no checkout session URL")?)
463 }
464
465 pub async fn checkout_with_price(
466 &self,
467 price_id: PriceId,
468 customer_id: stripe::CustomerId,
469 github_login: &str,
470 success_url: &str,
471 ) -> Result<String> {
472 let mut params = stripe::CreateCheckoutSession::new();
473 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
474 params.customer = Some(customer_id);
475 params.client_reference_id = Some(github_login);
476 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
477 price: Some(price_id.to_string()),
478 quantity: Some(1),
479 ..Default::default()
480 }]);
481 params.success_url = Some(success_url);
482
483 let session = stripe::CheckoutSession::create(&self.client, params).await?;
484 Ok(session.url.context("no checkout session URL")?)
485 }
486
487 pub async fn checkout_with_zed_pro_trial(
488 &self,
489 zed_pro_price_id: PriceId,
490 customer_id: stripe::CustomerId,
491 github_login: &str,
492 success_url: &str,
493 ) -> Result<String> {
494 let mut params = stripe::CreateCheckoutSession::new();
495 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
496 trial_period_days: Some(14),
497 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
498 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
499 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
500 }
501 }),
502 ..Default::default()
503 });
504 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
505 params.payment_method_collection =
506 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
507 params.customer = Some(customer_id);
508 params.client_reference_id = Some(github_login);
509 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
510 price: Some(zed_pro_price_id.to_string()),
511 quantity: Some(1),
512 ..Default::default()
513 }]);
514 params.success_url = Some(success_url);
515
516 let session = stripe::CheckoutSession::create(&self.client, params).await?;
517 Ok(session.url.context("no checkout session URL")?)
518 }
519}
520
521#[derive(Serialize)]
522struct DefaultAggregation {
523 formula: &'static str,
524}
525
526#[derive(Serialize)]
527struct StripeCreateMeterParams<'a> {
528 default_aggregation: DefaultAggregation,
529 display_name: String,
530 event_name: &'a str,
531}
532
533#[derive(Clone, Deserialize)]
534struct StripeMeter {
535 id: String,
536 event_name: String,
537}
538
539impl StripeMeter {
540 pub fn create(
541 client: &stripe::Client,
542 params: StripeCreateMeterParams,
543 ) -> stripe::Response<Self> {
544 client.post_form("/billing/meters", params)
545 }
546
547 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
548 #[derive(Serialize)]
549 struct Params {
550 #[serde(skip_serializing_if = "Option::is_none")]
551 limit: Option<u64>,
552 }
553
554 client.get_query("/billing/meters", Params { limit: Some(100) })
555 }
556}
557
558#[derive(Deserialize)]
559struct StripeMeterEvent {
560 identifier: String,
561}
562
563impl StripeMeterEvent {
564 pub async fn create(
565 client: &stripe::Client,
566 params: StripeCreateMeterEventParams<'_>,
567 ) -> Result<Self, stripe::StripeError> {
568 let identifier = params.identifier;
569 match client.post_form("/billing/meter_events", params).await {
570 Ok(event) => Ok(event),
571 Err(stripe::StripeError::Stripe(error)) => {
572 if error.http_status == 400
573 && error
574 .message
575 .as_ref()
576 .map_or(false, |message| message.contains(identifier))
577 {
578 Ok(Self {
579 identifier: identifier.to_string(),
580 })
581 } else {
582 Err(stripe::StripeError::Stripe(error))
583 }
584 }
585 Err(error) => Err(error),
586 }
587 }
588}
589
590#[derive(Serialize)]
591struct StripeCreateMeterEventParams<'a> {
592 identifier: &'a str,
593 event_name: &'a str,
594 payload: StripeCreateMeterEventPayload<'a>,
595 timestamp: Option<i64>,
596}
597
598#[derive(Serialize)]
599struct StripeCreateMeterEventPayload<'a> {
600 value: u64,
601 stripe_customer_id: &'a stripe::CustomerId,
602}
603
604fn subscription_contains_price(
605 subscription: &stripe::Subscription,
606 price_id: &stripe::PriceId,
607) -> bool {
608 subscription.items.data.iter().any(|item| {
609 item.price
610 .as_ref()
611 .map_or(false, |price| price.id == *price_id)
612 })
613}