1use std::sync::Arc;
2
3use crate::{Cents, Result, llm};
4use anyhow::Context as _;
5use chrono::{Datelike, Utc};
6use collections::HashMap;
7use serde::{Deserialize, Serialize};
8use stripe::PriceId;
9use tokio::sync::RwLock;
10
/// Wrapper around the Stripe API for usage-based LLM billing.
///
/// Keeps an in-memory cache of billing meters and metered prices so that
/// registering a model whose meters and prices are already known does not
/// issue any create requests to Stripe.
pub struct StripeBilling {
    /// Cached meters and prices; filled by `initialize` and `get_or_insert_price`.
    state: RwLock<StripeBillingState>,
    /// Shared Stripe API client used for every request.
    client: Arc<stripe::Client>,
}
15
/// Cache of Stripe billing objects held behind `StripeBilling::state`.
#[derive(Default)]
struct StripeBillingState {
    /// Billing meters keyed by their `event_name`.
    meters_by_event_name: HashMap<String, StripeMeter>,
    /// Metered price IDs keyed by the ID of the meter the price is attached to.
    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
}
21
/// The Stripe prices registered for a single LLM model, one per billed token
/// category. Produced by `StripeBilling::register_model`.
pub struct StripeModel {
    /// Price for regular input tokens.
    input_tokens_price: StripeBillingPrice,
    /// Price for input tokens written to the prompt cache.
    input_cache_creation_tokens_price: StripeBillingPrice,
    /// Price for input tokens read from the prompt cache.
    input_cache_read_tokens_price: StripeBillingPrice,
    /// Price for output tokens.
    output_tokens_price: StripeBillingPrice,
}
28
/// A metered Stripe price paired with the meter event name that feeds it.
struct StripeBillingPrice {
    /// Stripe price ID, used for subscription items and checkout line items.
    id: stripe::PriceId,
    /// Event name used when reporting usage against this price's meter.
    meter_event_name: String,
}
33
34impl StripeBilling {
35 pub fn new(client: Arc<stripe::Client>) -> Self {
36 Self {
37 client,
38 state: RwLock::default(),
39 }
40 }
41
42 pub async fn initialize(&self) -> Result<()> {
43 log::info!("StripeBilling: initializing");
44
45 let mut state = self.state.write().await;
46
47 let (meters, prices) = futures::try_join!(
48 StripeMeter::list(&self.client),
49 stripe::Price::list(
50 &self.client,
51 &stripe::ListPrices {
52 limit: Some(100),
53 ..Default::default()
54 }
55 )
56 )?;
57
58 for meter in meters.data {
59 state
60 .meters_by_event_name
61 .insert(meter.event_name.clone(), meter);
62 }
63
64 for price in prices.data {
65 if let Some(recurring) = price.recurring {
66 if let Some(meter) = recurring.meter {
67 state.price_ids_by_meter_id.insert(meter, price.id);
68 }
69 }
70 }
71
72 log::info!("StripeBilling: initialized");
73
74 Ok(())
75 }
76
77 pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
78 let input_tokens_price = self
79 .get_or_insert_price(
80 &format!("model_{}/input_tokens", model.id),
81 &format!("{} (Input Tokens)", model.name),
82 Cents::new(model.price_per_million_input_tokens as u32),
83 )
84 .await?;
85 let input_cache_creation_tokens_price = self
86 .get_or_insert_price(
87 &format!("model_{}/input_cache_creation_tokens", model.id),
88 &format!("{} (Input Cache Creation Tokens)", model.name),
89 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
90 )
91 .await?;
92 let input_cache_read_tokens_price = self
93 .get_or_insert_price(
94 &format!("model_{}/input_cache_read_tokens", model.id),
95 &format!("{} (Input Cache Read Tokens)", model.name),
96 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
97 )
98 .await?;
99 let output_tokens_price = self
100 .get_or_insert_price(
101 &format!("model_{}/output_tokens", model.id),
102 &format!("{} (Output Tokens)", model.name),
103 Cents::new(model.price_per_million_output_tokens as u32),
104 )
105 .await?;
106 Ok(StripeModel {
107 input_tokens_price,
108 input_cache_creation_tokens_price,
109 input_cache_read_tokens_price,
110 output_tokens_price,
111 })
112 }
113
114 async fn get_or_insert_price(
115 &self,
116 meter_event_name: &str,
117 price_description: &str,
118 price_per_million_tokens: Cents,
119 ) -> Result<StripeBillingPrice> {
120 // Fast code path when the meter and the price already exist.
121 {
122 let state = self.state.read().await;
123 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
124 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
125 return Ok(StripeBillingPrice {
126 id: price_id.clone(),
127 meter_event_name: meter_event_name.to_string(),
128 });
129 }
130 }
131 }
132
133 let mut state = self.state.write().await;
134 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
135 meter.clone()
136 } else {
137 let meter = StripeMeter::create(
138 &self.client,
139 StripeCreateMeterParams {
140 default_aggregation: DefaultAggregation { formula: "sum" },
141 display_name: price_description.to_string(),
142 event_name: meter_event_name,
143 },
144 )
145 .await?;
146 state
147 .meters_by_event_name
148 .insert(meter_event_name.to_string(), meter.clone());
149 meter
150 };
151
152 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
153 price_id.clone()
154 } else {
155 let price = stripe::Price::create(
156 &self.client,
157 stripe::CreatePrice {
158 active: Some(true),
159 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
160 currency: stripe::Currency::USD,
161 currency_options: None,
162 custom_unit_amount: None,
163 expand: &[],
164 lookup_key: None,
165 metadata: None,
166 nickname: None,
167 product: None,
168 product_data: Some(stripe::CreatePriceProductData {
169 id: None,
170 active: Some(true),
171 metadata: None,
172 name: price_description.to_string(),
173 statement_descriptor: None,
174 tax_code: None,
175 unit_label: None,
176 }),
177 recurring: Some(stripe::CreatePriceRecurring {
178 aggregate_usage: None,
179 interval: stripe::CreatePriceRecurringInterval::Month,
180 interval_count: None,
181 trial_period_days: None,
182 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
183 meter: Some(meter.id.clone()),
184 }),
185 tax_behavior: None,
186 tiers: None,
187 tiers_mode: None,
188 transfer_lookup_key: None,
189 transform_quantity: None,
190 unit_amount: None,
191 unit_amount_decimal: Some(&format!(
192 "{:.12}",
193 price_per_million_tokens.0 as f64 / 1_000_000f64
194 )),
195 },
196 )
197 .await?;
198 state
199 .price_ids_by_meter_id
200 .insert(meter.id, price.id.clone());
201 price.id
202 };
203
204 Ok(StripeBillingPrice {
205 id: price_id,
206 meter_event_name: meter_event_name.to_string(),
207 })
208 }
209
210 pub async fn subscribe_to_model(
211 &self,
212 subscription_id: &stripe::SubscriptionId,
213 model: &StripeModel,
214 ) -> Result<()> {
215 let subscription =
216 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
217
218 let mut items = Vec::new();
219
220 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
221 items.push(stripe::UpdateSubscriptionItems {
222 price: Some(model.input_tokens_price.id.to_string()),
223 ..Default::default()
224 });
225 }
226
227 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
228 {
229 items.push(stripe::UpdateSubscriptionItems {
230 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
231 ..Default::default()
232 });
233 }
234
235 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
236 items.push(stripe::UpdateSubscriptionItems {
237 price: Some(model.input_cache_read_tokens_price.id.to_string()),
238 ..Default::default()
239 });
240 }
241
242 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
243 items.push(stripe::UpdateSubscriptionItems {
244 price: Some(model.output_tokens_price.id.to_string()),
245 ..Default::default()
246 });
247 }
248
249 if !items.is_empty() {
250 items.extend(subscription.items.data.iter().map(|item| {
251 stripe::UpdateSubscriptionItems {
252 id: Some(item.id.to_string()),
253 ..Default::default()
254 }
255 }));
256
257 stripe::Subscription::update(
258 &self.client,
259 subscription_id,
260 stripe::UpdateSubscription {
261 items: Some(items),
262 ..Default::default()
263 },
264 )
265 .await?;
266 }
267
268 Ok(())
269 }
270
271 pub async fn bill_model_usage(
272 &self,
273 customer_id: &stripe::CustomerId,
274 model: &StripeModel,
275 event: &llm::db::billing_event::Model,
276 ) -> Result<()> {
277 let timestamp = Utc::now().timestamp();
278
279 if event.input_tokens > 0 {
280 StripeMeterEvent::create(
281 &self.client,
282 StripeCreateMeterEventParams {
283 identifier: &format!("input_tokens/{}", event.idempotency_key),
284 event_name: &model.input_tokens_price.meter_event_name,
285 payload: StripeCreateMeterEventPayload {
286 value: event.input_tokens as u64,
287 stripe_customer_id: customer_id,
288 },
289 timestamp: Some(timestamp),
290 },
291 )
292 .await?;
293 }
294
295 if event.input_cache_creation_tokens > 0 {
296 StripeMeterEvent::create(
297 &self.client,
298 StripeCreateMeterEventParams {
299 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
300 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
301 payload: StripeCreateMeterEventPayload {
302 value: event.input_cache_creation_tokens as u64,
303 stripe_customer_id: customer_id,
304 },
305 timestamp: Some(timestamp),
306 },
307 )
308 .await?;
309 }
310
311 if event.input_cache_read_tokens > 0 {
312 StripeMeterEvent::create(
313 &self.client,
314 StripeCreateMeterEventParams {
315 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
316 event_name: &model.input_cache_read_tokens_price.meter_event_name,
317 payload: StripeCreateMeterEventPayload {
318 value: event.input_cache_read_tokens as u64,
319 stripe_customer_id: customer_id,
320 },
321 timestamp: Some(timestamp),
322 },
323 )
324 .await?;
325 }
326
327 if event.output_tokens > 0 {
328 StripeMeterEvent::create(
329 &self.client,
330 StripeCreateMeterEventParams {
331 identifier: &format!("output_tokens/{}", event.idempotency_key),
332 event_name: &model.output_tokens_price.meter_event_name,
333 payload: StripeCreateMeterEventPayload {
334 value: event.output_tokens as u64,
335 stripe_customer_id: customer_id,
336 },
337 timestamp: Some(timestamp),
338 },
339 )
340 .await?;
341 }
342
343 Ok(())
344 }
345
346 pub async fn checkout(
347 &self,
348 customer_id: stripe::CustomerId,
349 github_login: &str,
350 model: &StripeModel,
351 success_url: &str,
352 ) -> Result<String> {
353 let first_of_next_month = Utc::now()
354 .checked_add_months(chrono::Months::new(1))
355 .unwrap()
356 .with_day(1)
357 .unwrap();
358
359 let mut params = stripe::CreateCheckoutSession::new();
360 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
361 params.customer = Some(customer_id);
362 params.client_reference_id = Some(github_login);
363 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
364 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
365 ..Default::default()
366 });
367 params.line_items = Some(
368 [
369 &model.input_tokens_price.id,
370 &model.input_cache_creation_tokens_price.id,
371 &model.input_cache_read_tokens_price.id,
372 &model.output_tokens_price.id,
373 ]
374 .into_iter()
375 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
376 price: Some(price_id.to_string()),
377 ..Default::default()
378 })
379 .collect(),
380 );
381 params.success_url = Some(success_url);
382
383 let session = stripe::CheckoutSession::create(&self.client, params).await?;
384 Ok(session.url.context("no checkout session URL")?)
385 }
386
387 pub async fn checkout_with_price(
388 &self,
389 price_id: PriceId,
390 customer_id: stripe::CustomerId,
391 github_login: &str,
392 success_url: &str,
393 ) -> Result<String> {
394 let mut params = stripe::CreateCheckoutSession::new();
395 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
396 params.customer = Some(customer_id);
397 params.client_reference_id = Some(github_login);
398 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
399 price: Some(price_id.to_string()),
400 quantity: Some(1),
401 ..Default::default()
402 }]);
403 params.success_url = Some(success_url);
404
405 let session = stripe::CheckoutSession::create(&self.client, params).await?;
406 Ok(session.url.context("no checkout session URL")?)
407 }
408
409 pub async fn checkout_with_zed_pro_trial(
410 &self,
411 zed_pro_price_id: PriceId,
412 customer_id: stripe::CustomerId,
413 github_login: &str,
414 success_url: &str,
415 ) -> Result<String> {
416 let mut params = stripe::CreateCheckoutSession::new();
417 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
418 trial_period_days: Some(14),
419 ..Default::default()
420 });
421 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
422 params.customer = Some(customer_id);
423 params.client_reference_id = Some(github_login);
424 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
425 price: Some(zed_pro_price_id.to_string()),
426 quantity: Some(1),
427 ..Default::default()
428 }]);
429 params.success_url = Some(success_url);
430
431 let session = stripe::CheckoutSession::create(&self.client, params).await?;
432 Ok(session.url.context("no checkout session URL")?)
433 }
434}
435
/// The `default_aggregation` setting sent when creating a billing meter.
#[derive(Serialize)]
struct DefaultAggregation {
    /// Aggregation formula name; this file only ever sends `"sum"`.
    formula: &'static str,
}
440
/// Form parameters for creating a billing meter (`POST /billing/meters`).
///
/// Field names and order are serialized as-is.
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
    /// How reported values are aggregated.
    default_aggregation: DefaultAggregation,
    /// Human-readable meter name.
    display_name: String,
    /// Event name that usage for this meter is reported under.
    event_name: &'a str,
}
447
/// Subset of a Stripe billing meter object; other response fields are ignored.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// Stripe meter ID, referenced by a metered price's `recurring.meter`.
    id: String,
    /// Event name that usage for this meter is reported under.
    event_name: String,
}
453
454impl StripeMeter {
455 pub fn create(
456 client: &stripe::Client,
457 params: StripeCreateMeterParams,
458 ) -> stripe::Response<Self> {
459 client.post_form("/billing/meters", params)
460 }
461
462 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
463 #[derive(Serialize)]
464 struct Params {
465 #[serde(skip_serializing_if = "Option::is_none")]
466 limit: Option<u64>,
467 }
468
469 client.get_query("/billing/meters", Params { limit: Some(100) })
470 }
471}
472
/// Result of recording a meter event.
///
/// Deserialized from Stripe's response, or built locally by
/// `StripeMeterEvent::create` when Stripe answers with a 400 whose message
/// mentions the event's identifier.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier of the recorded event.
    identifier: String,
}
477
478impl StripeMeterEvent {
479 pub async fn create(
480 client: &stripe::Client,
481 params: StripeCreateMeterEventParams<'_>,
482 ) -> Result<Self, stripe::StripeError> {
483 let identifier = params.identifier;
484 match client.post_form("/billing/meter_events", params).await {
485 Ok(event) => Ok(event),
486 Err(stripe::StripeError::Stripe(error)) => {
487 if error.http_status == 400
488 && error
489 .message
490 .as_ref()
491 .map_or(false, |message| message.contains(identifier))
492 {
493 Ok(Self {
494 identifier: identifier.to_string(),
495 })
496 } else {
497 Err(stripe::StripeError::Stripe(error))
498 }
499 }
500 Err(error) => Err(error),
501 }
502 }
503}
504
/// Form parameters for recording a meter event (`POST /billing/meter_events`).
///
/// Field names and order are serialized as-is.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Event identifier; in this file it is derived from a billing event's
    /// idempotency key.
    identifier: &'a str,
    /// Event name of the meter this usage is reported against.
    event_name: &'a str,
    /// Usage quantity and the customer it belongs to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time as a Unix timestamp in seconds.
    timestamp: Option<i64>,
}
512
/// Payload of a meter event: the usage quantity and the customer it is billed to.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Usage quantity; in this file, a token count.
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
518
519fn subscription_contains_price(
520 subscription: &stripe::Subscription,
521 price_id: &stripe::PriceId,
522) -> bool {
523 subscription.items.data.iter().any(|item| {
524 item.price
525 .as_ref()
526 .map_or(false, |price| price.id == *price_id)
527 })
528}