1use std::sync::Arc;
2
3use crate::{llm, Cents, Result};
4use anyhow::Context;
5use chrono::Utc;
6use collections::HashMap;
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9
10pub struct StripeBilling {
11 state: RwLock<StripeBillingState>,
12 client: Arc<stripe::Client>,
13}
14
15#[derive(Default)]
16struct StripeBillingState {
17 meters_by_event_name: HashMap<String, StripeMeter>,
18 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
19}
20
21pub struct StripeModel {
22 input_tokens_price: StripeBillingPrice,
23 input_cache_creation_tokens_price: StripeBillingPrice,
24 input_cache_read_tokens_price: StripeBillingPrice,
25 output_tokens_price: StripeBillingPrice,
26}
27
28struct StripeBillingPrice {
29 id: stripe::PriceId,
30 meter_event_name: String,
31}
32
33impl StripeBilling {
34 pub fn new(client: Arc<stripe::Client>) -> Self {
35 Self {
36 client,
37 state: RwLock::default(),
38 }
39 }
40
41 pub async fn initialize(&self) -> Result<()> {
42 log::info!("StripeBilling: initializing");
43
44 let mut state = self.state.write().await;
45
46 let (meters, prices) = futures::try_join!(
47 StripeMeter::list(&self.client),
48 stripe::Price::list(
49 &self.client,
50 &stripe::ListPrices {
51 limit: Some(100),
52 ..Default::default()
53 }
54 )
55 )?;
56
57 for meter in meters.data {
58 state
59 .meters_by_event_name
60 .insert(meter.event_name.clone(), meter);
61 }
62
63 for price in prices.data {
64 if let Some(recurring) = price.recurring {
65 if let Some(meter) = recurring.meter {
66 state.price_ids_by_meter_id.insert(meter, price.id);
67 }
68 }
69 }
70
71 log::info!("StripeBilling: initialized");
72
73 Ok(())
74 }
75
76 pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
77 let input_tokens_price = self
78 .get_or_insert_price(
79 &format!("model_{}/input_tokens", model.id),
80 &format!("{} (Input Tokens)", model.name),
81 Cents::new(model.price_per_million_input_tokens as u32),
82 )
83 .await?;
84 let input_cache_creation_tokens_price = self
85 .get_or_insert_price(
86 &format!("model_{}/input_cache_creation_tokens", model.id),
87 &format!("{} (Input Cache Creation Tokens)", model.name),
88 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
89 )
90 .await?;
91 let input_cache_read_tokens_price = self
92 .get_or_insert_price(
93 &format!("model_{}/input_cache_read_tokens", model.id),
94 &format!("{} (Input Cache Read Tokens)", model.name),
95 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
96 )
97 .await?;
98 let output_tokens_price = self
99 .get_or_insert_price(
100 &format!("model_{}/output_tokens", model.id),
101 &format!("{} (Output Tokens)", model.name),
102 Cents::new(model.price_per_million_output_tokens as u32),
103 )
104 .await?;
105 Ok(StripeModel {
106 input_tokens_price,
107 input_cache_creation_tokens_price,
108 input_cache_read_tokens_price,
109 output_tokens_price,
110 })
111 }
112
113 async fn get_or_insert_price(
114 &self,
115 meter_event_name: &str,
116 price_description: &str,
117 price_per_million_tokens: Cents,
118 ) -> Result<StripeBillingPrice> {
119 // Fast code path when the meter and the price already exist.
120 {
121 let state = self.state.read().await;
122 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
123 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
124 return Ok(StripeBillingPrice {
125 id: price_id.clone(),
126 meter_event_name: meter_event_name.to_string(),
127 });
128 }
129 }
130 }
131
132 let mut state = self.state.write().await;
133 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
134 meter.clone()
135 } else {
136 let meter = StripeMeter::create(
137 &self.client,
138 StripeCreateMeterParams {
139 default_aggregation: DefaultAggregation { formula: "sum" },
140 display_name: price_description.to_string(),
141 event_name: meter_event_name,
142 },
143 )
144 .await?;
145 state
146 .meters_by_event_name
147 .insert(meter_event_name.to_string(), meter.clone());
148 meter
149 };
150
151 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
152 price_id.clone()
153 } else {
154 let price = stripe::Price::create(
155 &self.client,
156 stripe::CreatePrice {
157 active: Some(true),
158 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
159 currency: stripe::Currency::USD,
160 currency_options: None,
161 custom_unit_amount: None,
162 expand: &[],
163 lookup_key: None,
164 metadata: None,
165 nickname: None,
166 product: None,
167 product_data: Some(stripe::CreatePriceProductData {
168 id: None,
169 active: Some(true),
170 metadata: None,
171 name: price_description.to_string(),
172 statement_descriptor: None,
173 tax_code: None,
174 unit_label: None,
175 }),
176 recurring: Some(stripe::CreatePriceRecurring {
177 aggregate_usage: None,
178 interval: stripe::CreatePriceRecurringInterval::Month,
179 interval_count: None,
180 trial_period_days: None,
181 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
182 meter: Some(meter.id.clone()),
183 }),
184 tax_behavior: None,
185 tiers: None,
186 tiers_mode: None,
187 transfer_lookup_key: None,
188 transform_quantity: None,
189 unit_amount: None,
190 unit_amount_decimal: Some(&format!(
191 "{:.12}",
192 price_per_million_tokens.0 as f64 / 1_000_000f64
193 )),
194 },
195 )
196 .await?;
197 state
198 .price_ids_by_meter_id
199 .insert(meter.id, price.id.clone());
200 price.id
201 };
202
203 Ok(StripeBillingPrice {
204 id: price_id,
205 meter_event_name: meter_event_name.to_string(),
206 })
207 }
208
209 pub async fn subscribe_to_model(
210 &self,
211 subscription_id: &stripe::SubscriptionId,
212 model: &StripeModel,
213 ) -> Result<()> {
214 let subscription =
215 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
216
217 let mut items = Vec::new();
218
219 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
220 items.push(stripe::UpdateSubscriptionItems {
221 price: Some(model.input_tokens_price.id.to_string()),
222 ..Default::default()
223 });
224 }
225
226 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
227 {
228 items.push(stripe::UpdateSubscriptionItems {
229 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
230 ..Default::default()
231 });
232 }
233
234 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
235 items.push(stripe::UpdateSubscriptionItems {
236 price: Some(model.input_cache_read_tokens_price.id.to_string()),
237 ..Default::default()
238 });
239 }
240
241 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
242 items.push(stripe::UpdateSubscriptionItems {
243 price: Some(model.output_tokens_price.id.to_string()),
244 ..Default::default()
245 });
246 }
247
248 if !items.is_empty() {
249 items.extend(subscription.items.data.iter().map(|item| {
250 stripe::UpdateSubscriptionItems {
251 id: Some(item.id.to_string()),
252 ..Default::default()
253 }
254 }));
255
256 stripe::Subscription::update(
257 &self.client,
258 subscription_id,
259 stripe::UpdateSubscription {
260 items: Some(items),
261 ..Default::default()
262 },
263 )
264 .await?;
265 }
266
267 Ok(())
268 }
269
270 pub async fn bill_model_usage(
271 &self,
272 customer_id: &stripe::CustomerId,
273 model: &StripeModel,
274 event: &llm::db::billing_event::Model,
275 ) -> Result<()> {
276 let timestamp = Utc::now().timestamp();
277
278 if event.input_tokens > 0 {
279 StripeMeterEvent::create(
280 &self.client,
281 StripeCreateMeterEventParams {
282 identifier: &format!("input_tokens/{}", event.idempotency_key),
283 event_name: &model.input_tokens_price.meter_event_name,
284 payload: StripeCreateMeterEventPayload {
285 value: event.input_tokens as u64,
286 stripe_customer_id: customer_id,
287 },
288 timestamp: Some(timestamp),
289 },
290 )
291 .await?;
292 }
293
294 if event.input_cache_creation_tokens > 0 {
295 StripeMeterEvent::create(
296 &self.client,
297 StripeCreateMeterEventParams {
298 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
299 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
300 payload: StripeCreateMeterEventPayload {
301 value: event.input_cache_creation_tokens as u64,
302 stripe_customer_id: customer_id,
303 },
304 timestamp: Some(timestamp),
305 },
306 )
307 .await?;
308 }
309
310 if event.input_cache_read_tokens > 0 {
311 StripeMeterEvent::create(
312 &self.client,
313 StripeCreateMeterEventParams {
314 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
315 event_name: &model.input_cache_read_tokens_price.meter_event_name,
316 payload: StripeCreateMeterEventPayload {
317 value: event.input_cache_read_tokens as u64,
318 stripe_customer_id: customer_id,
319 },
320 timestamp: Some(timestamp),
321 },
322 )
323 .await?;
324 }
325
326 if event.output_tokens > 0 {
327 StripeMeterEvent::create(
328 &self.client,
329 StripeCreateMeterEventParams {
330 identifier: &format!("output_tokens/{}", event.idempotency_key),
331 event_name: &model.output_tokens_price.meter_event_name,
332 payload: StripeCreateMeterEventPayload {
333 value: event.output_tokens as u64,
334 stripe_customer_id: customer_id,
335 },
336 timestamp: Some(timestamp),
337 },
338 )
339 .await?;
340 }
341
342 Ok(())
343 }
344
345 pub async fn checkout(
346 &self,
347 customer_id: stripe::CustomerId,
348 github_login: &str,
349 model: &StripeModel,
350 success_url: &str,
351 ) -> Result<String> {
352 let mut params = stripe::CreateCheckoutSession::new();
353 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
354 params.customer = Some(customer_id);
355 params.client_reference_id = Some(github_login);
356 params.line_items = Some(
357 [
358 &model.input_tokens_price.id,
359 &model.input_cache_creation_tokens_price.id,
360 &model.input_cache_read_tokens_price.id,
361 &model.output_tokens_price.id,
362 ]
363 .into_iter()
364 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
365 price: Some(price_id.to_string()),
366 ..Default::default()
367 })
368 .collect(),
369 );
370 params.success_url = Some(success_url);
371
372 let session = stripe::CheckoutSession::create(&self.client, params).await?;
373 Ok(session.url.context("no checkout session URL")?)
374 }
375}
376
377#[derive(Serialize)]
378struct DefaultAggregation {
379 formula: &'static str,
380}
381
382#[derive(Serialize)]
383struct StripeCreateMeterParams<'a> {
384 default_aggregation: DefaultAggregation,
385 display_name: String,
386 event_name: &'a str,
387}
388
389#[derive(Clone, Deserialize)]
390struct StripeMeter {
391 id: String,
392 event_name: String,
393}
394
395impl StripeMeter {
396 pub fn create(
397 client: &stripe::Client,
398 params: StripeCreateMeterParams,
399 ) -> stripe::Response<Self> {
400 client.post_form("/billing/meters", params)
401 }
402
403 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
404 #[derive(Serialize)]
405 struct Params {
406 #[serde(skip_serializing_if = "Option::is_none")]
407 limit: Option<u64>,
408 }
409
410 client.get_query("/billing/meters", Params { limit: Some(100) })
411 }
412}
413
414#[derive(Deserialize)]
415struct StripeMeterEvent {
416 identifier: String,
417}
418
419impl StripeMeterEvent {
420 pub async fn create(
421 client: &stripe::Client,
422 params: StripeCreateMeterEventParams<'_>,
423 ) -> Result<Self, stripe::StripeError> {
424 let identifier = params.identifier;
425 match client.post_form("/billing/meter_events", params).await {
426 Ok(event) => Ok(event),
427 Err(stripe::StripeError::Stripe(error)) => {
428 if error.http_status == 400
429 && error
430 .message
431 .as_ref()
432 .map_or(false, |message| message.contains(identifier))
433 {
434 Ok(Self {
435 identifier: identifier.to_string(),
436 })
437 } else {
438 Err(stripe::StripeError::Stripe(error))
439 }
440 }
441 Err(error) => Err(error),
442 }
443 }
444}
445
446#[derive(Serialize)]
447struct StripeCreateMeterEventParams<'a> {
448 identifier: &'a str,
449 event_name: &'a str,
450 payload: StripeCreateMeterEventPayload<'a>,
451 timestamp: Option<i64>,
452}
453
454#[derive(Serialize)]
455struct StripeCreateMeterEventPayload<'a> {
456 value: u64,
457 stripe_customer_id: &'a stripe::CustomerId,
458}
459
460fn subscription_contains_price(
461 subscription: &stripe::Subscription,
462 price_id: &stripe::PriceId,
463) -> bool {
464 subscription.items.data.iter().any(|item| {
465 item.price
466 .as_ref()
467 .map_or(false, |price| price.id == *price_id)
468 })
469}