1use std::sync::Arc;
2
3use crate::{Cents, Result, llm};
4use anyhow::Context as _;
5use chrono::{Datelike, Utc};
6use collections::HashMap;
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9
/// Stripe-backed usage billing for LLM models.
///
/// Keeps an in-memory cache of Stripe billing meters and metered prices
/// (`StripeBillingState`) so that looking up an existing price does not require
/// an API round trip.
pub struct StripeBilling {
    /// Cached meters and prices; filled by `initialize` and extended by
    /// `get_or_insert_price` when it creates new Stripe objects.
    state: RwLock<StripeBillingState>,
    /// Shared Stripe API client used for every request.
    client: Arc<stripe::Client>,
}
14
/// In-memory cache of Stripe billing objects, keyed for direct lookup.
#[derive(Default)]
struct StripeBillingState {
    /// Meters keyed by their event name (e.g. `model_<id>/input_tokens`).
    meters_by_event_name: HashMap<String, StripeMeter>,
    /// Price IDs keyed by the ID of the meter the price is attached to.
    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
}
20
/// The Stripe prices registered for one model, one per billed token kind.
pub struct StripeModel {
    /// Price for regular input tokens.
    input_tokens_price: StripeBillingPrice,
    /// Price for input tokens written to the prompt cache.
    input_cache_creation_tokens_price: StripeBillingPrice,
    /// Price for input tokens read from the prompt cache.
    input_cache_read_tokens_price: StripeBillingPrice,
    /// Price for output tokens.
    output_tokens_price: StripeBillingPrice,
}
27
/// A Stripe price together with the event name of the meter it bills against.
struct StripeBillingPrice {
    /// Stripe price ID, used for subscription items and checkout line items.
    id: stripe::PriceId,
    /// Event name that usage for this price is reported under.
    meter_event_name: String,
}
32
33impl StripeBilling {
34 pub fn new(client: Arc<stripe::Client>) -> Self {
35 Self {
36 client,
37 state: RwLock::default(),
38 }
39 }
40
41 pub async fn initialize(&self) -> Result<()> {
42 log::info!("StripeBilling: initializing");
43
44 let mut state = self.state.write().await;
45
46 let (meters, prices) = futures::try_join!(
47 StripeMeter::list(&self.client),
48 stripe::Price::list(
49 &self.client,
50 &stripe::ListPrices {
51 limit: Some(100),
52 ..Default::default()
53 }
54 )
55 )?;
56
57 for meter in meters.data {
58 state
59 .meters_by_event_name
60 .insert(meter.event_name.clone(), meter);
61 }
62
63 for price in prices.data {
64 if let Some(recurring) = price.recurring {
65 if let Some(meter) = recurring.meter {
66 state.price_ids_by_meter_id.insert(meter, price.id);
67 }
68 }
69 }
70
71 log::info!("StripeBilling: initialized");
72
73 Ok(())
74 }
75
76 pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
77 let input_tokens_price = self
78 .get_or_insert_price(
79 &format!("model_{}/input_tokens", model.id),
80 &format!("{} (Input Tokens)", model.name),
81 Cents::new(model.price_per_million_input_tokens as u32),
82 )
83 .await?;
84 let input_cache_creation_tokens_price = self
85 .get_or_insert_price(
86 &format!("model_{}/input_cache_creation_tokens", model.id),
87 &format!("{} (Input Cache Creation Tokens)", model.name),
88 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
89 )
90 .await?;
91 let input_cache_read_tokens_price = self
92 .get_or_insert_price(
93 &format!("model_{}/input_cache_read_tokens", model.id),
94 &format!("{} (Input Cache Read Tokens)", model.name),
95 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
96 )
97 .await?;
98 let output_tokens_price = self
99 .get_or_insert_price(
100 &format!("model_{}/output_tokens", model.id),
101 &format!("{} (Output Tokens)", model.name),
102 Cents::new(model.price_per_million_output_tokens as u32),
103 )
104 .await?;
105 Ok(StripeModel {
106 input_tokens_price,
107 input_cache_creation_tokens_price,
108 input_cache_read_tokens_price,
109 output_tokens_price,
110 })
111 }
112
113 async fn get_or_insert_price(
114 &self,
115 meter_event_name: &str,
116 price_description: &str,
117 price_per_million_tokens: Cents,
118 ) -> Result<StripeBillingPrice> {
119 // Fast code path when the meter and the price already exist.
120 {
121 let state = self.state.read().await;
122 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
123 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
124 return Ok(StripeBillingPrice {
125 id: price_id.clone(),
126 meter_event_name: meter_event_name.to_string(),
127 });
128 }
129 }
130 }
131
132 let mut state = self.state.write().await;
133 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
134 meter.clone()
135 } else {
136 let meter = StripeMeter::create(
137 &self.client,
138 StripeCreateMeterParams {
139 default_aggregation: DefaultAggregation { formula: "sum" },
140 display_name: price_description.to_string(),
141 event_name: meter_event_name,
142 },
143 )
144 .await?;
145 state
146 .meters_by_event_name
147 .insert(meter_event_name.to_string(), meter.clone());
148 meter
149 };
150
151 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
152 price_id.clone()
153 } else {
154 let price = stripe::Price::create(
155 &self.client,
156 stripe::CreatePrice {
157 active: Some(true),
158 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
159 currency: stripe::Currency::USD,
160 currency_options: None,
161 custom_unit_amount: None,
162 expand: &[],
163 lookup_key: None,
164 metadata: None,
165 nickname: None,
166 product: None,
167 product_data: Some(stripe::CreatePriceProductData {
168 id: None,
169 active: Some(true),
170 metadata: None,
171 name: price_description.to_string(),
172 statement_descriptor: None,
173 tax_code: None,
174 unit_label: None,
175 }),
176 recurring: Some(stripe::CreatePriceRecurring {
177 aggregate_usage: None,
178 interval: stripe::CreatePriceRecurringInterval::Month,
179 interval_count: None,
180 trial_period_days: None,
181 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
182 meter: Some(meter.id.clone()),
183 }),
184 tax_behavior: None,
185 tiers: None,
186 tiers_mode: None,
187 transfer_lookup_key: None,
188 transform_quantity: None,
189 unit_amount: None,
190 unit_amount_decimal: Some(&format!(
191 "{:.12}",
192 price_per_million_tokens.0 as f64 / 1_000_000f64
193 )),
194 },
195 )
196 .await?;
197 state
198 .price_ids_by_meter_id
199 .insert(meter.id, price.id.clone());
200 price.id
201 };
202
203 Ok(StripeBillingPrice {
204 id: price_id,
205 meter_event_name: meter_event_name.to_string(),
206 })
207 }
208
209 pub async fn subscribe_to_model(
210 &self,
211 subscription_id: &stripe::SubscriptionId,
212 model: &StripeModel,
213 ) -> Result<()> {
214 let subscription =
215 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
216
217 let mut items = Vec::new();
218
219 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
220 items.push(stripe::UpdateSubscriptionItems {
221 price: Some(model.input_tokens_price.id.to_string()),
222 ..Default::default()
223 });
224 }
225
226 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
227 {
228 items.push(stripe::UpdateSubscriptionItems {
229 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
230 ..Default::default()
231 });
232 }
233
234 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
235 items.push(stripe::UpdateSubscriptionItems {
236 price: Some(model.input_cache_read_tokens_price.id.to_string()),
237 ..Default::default()
238 });
239 }
240
241 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
242 items.push(stripe::UpdateSubscriptionItems {
243 price: Some(model.output_tokens_price.id.to_string()),
244 ..Default::default()
245 });
246 }
247
248 if !items.is_empty() {
249 items.extend(subscription.items.data.iter().map(|item| {
250 stripe::UpdateSubscriptionItems {
251 id: Some(item.id.to_string()),
252 ..Default::default()
253 }
254 }));
255
256 stripe::Subscription::update(
257 &self.client,
258 subscription_id,
259 stripe::UpdateSubscription {
260 items: Some(items),
261 ..Default::default()
262 },
263 )
264 .await?;
265 }
266
267 Ok(())
268 }
269
270 pub async fn bill_model_usage(
271 &self,
272 customer_id: &stripe::CustomerId,
273 model: &StripeModel,
274 event: &llm::db::billing_event::Model,
275 ) -> Result<()> {
276 let timestamp = Utc::now().timestamp();
277
278 if event.input_tokens > 0 {
279 StripeMeterEvent::create(
280 &self.client,
281 StripeCreateMeterEventParams {
282 identifier: &format!("input_tokens/{}", event.idempotency_key),
283 event_name: &model.input_tokens_price.meter_event_name,
284 payload: StripeCreateMeterEventPayload {
285 value: event.input_tokens as u64,
286 stripe_customer_id: customer_id,
287 },
288 timestamp: Some(timestamp),
289 },
290 )
291 .await?;
292 }
293
294 if event.input_cache_creation_tokens > 0 {
295 StripeMeterEvent::create(
296 &self.client,
297 StripeCreateMeterEventParams {
298 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
299 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
300 payload: StripeCreateMeterEventPayload {
301 value: event.input_cache_creation_tokens as u64,
302 stripe_customer_id: customer_id,
303 },
304 timestamp: Some(timestamp),
305 },
306 )
307 .await?;
308 }
309
310 if event.input_cache_read_tokens > 0 {
311 StripeMeterEvent::create(
312 &self.client,
313 StripeCreateMeterEventParams {
314 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
315 event_name: &model.input_cache_read_tokens_price.meter_event_name,
316 payload: StripeCreateMeterEventPayload {
317 value: event.input_cache_read_tokens as u64,
318 stripe_customer_id: customer_id,
319 },
320 timestamp: Some(timestamp),
321 },
322 )
323 .await?;
324 }
325
326 if event.output_tokens > 0 {
327 StripeMeterEvent::create(
328 &self.client,
329 StripeCreateMeterEventParams {
330 identifier: &format!("output_tokens/{}", event.idempotency_key),
331 event_name: &model.output_tokens_price.meter_event_name,
332 payload: StripeCreateMeterEventPayload {
333 value: event.output_tokens as u64,
334 stripe_customer_id: customer_id,
335 },
336 timestamp: Some(timestamp),
337 },
338 )
339 .await?;
340 }
341
342 Ok(())
343 }
344
345 pub async fn checkout(
346 &self,
347 customer_id: stripe::CustomerId,
348 github_login: &str,
349 model: &StripeModel,
350 success_url: &str,
351 ) -> Result<String> {
352 let first_of_next_month = Utc::now()
353 .checked_add_months(chrono::Months::new(1))
354 .unwrap()
355 .with_day(1)
356 .unwrap();
357
358 let mut params = stripe::CreateCheckoutSession::new();
359 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
360 params.customer = Some(customer_id);
361 params.client_reference_id = Some(github_login);
362 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
363 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
364 ..Default::default()
365 });
366 params.line_items = Some(
367 [
368 &model.input_tokens_price.id,
369 &model.input_cache_creation_tokens_price.id,
370 &model.input_cache_read_tokens_price.id,
371 &model.output_tokens_price.id,
372 ]
373 .into_iter()
374 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
375 price: Some(price_id.to_string()),
376 ..Default::default()
377 })
378 .collect(),
379 );
380 params.success_url = Some(success_url);
381
382 let session = stripe::CheckoutSession::create(&self.client, params).await?;
383 Ok(session.url.context("no checkout session URL")?)
384 }
385}
386
/// The `default_aggregation` parameter of the create-meter request.
#[derive(Serialize)]
struct DefaultAggregation {
    /// Aggregation formula; this file only passes `"sum"`.
    formula: &'static str,
}
391
/// Form body sent to `POST /billing/meters` by `StripeMeter::create`.
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
    /// How reported values are combined by the meter.
    default_aggregation: DefaultAggregation,
    /// Name for the meter; `get_or_insert_price` passes the price description.
    display_name: String,
    /// Event name that usage is later reported under via meter events.
    event_name: &'a str,
}
398
/// The subset of a Stripe billing meter object that this file reads.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// Stripe meter ID; metered prices reference it through `recurring.meter`.
    id: String,
    /// Event name the meter aggregates; used as the cache key.
    event_name: String,
}
404
405impl StripeMeter {
406 pub fn create(
407 client: &stripe::Client,
408 params: StripeCreateMeterParams,
409 ) -> stripe::Response<Self> {
410 client.post_form("/billing/meters", params)
411 }
412
413 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
414 #[derive(Serialize)]
415 struct Params {
416 #[serde(skip_serializing_if = "Option::is_none")]
417 limit: Option<u64>,
418 }
419
420 client.get_query("/billing/meters", Params { limit: Some(100) })
421 }
422}
423
/// Response of a meter-event creation; only the event identifier is kept.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier of the recorded event.
    // NOTE(review): not read in the visible code; kept for deserialization only.
    identifier: String,
}
428
429impl StripeMeterEvent {
430 pub async fn create(
431 client: &stripe::Client,
432 params: StripeCreateMeterEventParams<'_>,
433 ) -> Result<Self, stripe::StripeError> {
434 let identifier = params.identifier;
435 match client.post_form("/billing/meter_events", params).await {
436 Ok(event) => Ok(event),
437 Err(stripe::StripeError::Stripe(error)) => {
438 if error.http_status == 400
439 && error
440 .message
441 .as_ref()
442 .map_or(false, |message| message.contains(identifier))
443 {
444 Ok(Self {
445 identifier: identifier.to_string(),
446 })
447 } else {
448 Err(stripe::StripeError::Stripe(error))
449 }
450 }
451 Err(error) => Err(error),
452 }
453 }
454}
455
/// Form body sent to `POST /billing/meter_events` by `StripeMeterEvent::create`.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Event identifier; `bill_model_usage` builds it as `<kind>/<idempotency_key>`.
    identifier: &'a str,
    /// Event name of the meter that should count this usage.
    event_name: &'a str,
    /// Usage quantity and the customer it belongs to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Unix timestamp in seconds (callers pass `Utc::now().timestamp()`).
    timestamp: Option<i64>,
}
463
/// Meter-event payload: the usage quantity and the customer it is attributed to.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Number of tokens being reported.
    value: u64,
    /// Stripe customer to attribute the usage to.
    stripe_customer_id: &'a stripe::CustomerId,
}
469
470fn subscription_contains_price(
471 subscription: &stripe::Subscription,
472 price_id: &stripe::PriceId,
473) -> bool {
474 subscription.items.data.iter().any(|item| {
475 item.price
476 .as_ref()
477 .map_or(false, |price| price.id == *price_id)
478 })
479}