1use std::sync::Arc;
2
3use crate::{llm, Cents, Result};
4use anyhow::Context;
5use chrono::Utc;
6use collections::HashMap;
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9
/// Stripe client plus two in-memory caches used for usage-based billing.
pub struct StripeBilling {
    /// Known billing meters, keyed by the meter's `event_name`.
    meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
    /// Stripe price ids, keyed by the id of the meter the price is attached to.
    price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
    /// Shared Stripe API client used for every request made by this type.
    client: Arc<stripe::Client>,
}
15
/// The four Stripe prices registered for one model, one per token category.
pub struct StripeModel {
    /// Price used for input tokens.
    input_tokens_price: StripeBillingPrice,
    /// Price used for input cache creation tokens.
    input_cache_creation_tokens_price: StripeBillingPrice,
    /// Price used for input cache read tokens.
    input_cache_read_tokens_price: StripeBillingPrice,
    /// Price used for output tokens.
    output_tokens_price: StripeBillingPrice,
}
22
/// A Stripe price together with the meter event name used to report usage against it.
struct StripeBillingPrice {
    /// Id of the Stripe price.
    id: stripe::PriceId,
    /// Event name sent with meter events for this price.
    meter_event_name: String,
}
27
28impl StripeBilling {
29 pub fn new(client: Arc<stripe::Client>) -> Self {
30 Self {
31 client,
32 meters_by_event_name: RwLock::new(HashMap::default()),
33 price_ids_by_meter_id: RwLock::new(HashMap::default()),
34 }
35 }
36
37 pub async fn initialize(&self) -> Result<()> {
38 log::info!("initializing StripeBilling");
39
40 {
41 let meters = StripeMeter::list(&self.client).await?.data;
42 let mut meters_by_event_name = self.meters_by_event_name.write().await;
43 for meter in meters {
44 meters_by_event_name.insert(meter.event_name.clone(), meter);
45 }
46 }
47
48 {
49 let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
50 .await?
51 .data;
52 let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
53 for price in prices {
54 if let Some(recurring) = price.recurring {
55 if let Some(meter) = recurring.meter {
56 price_ids_by_meter_id.insert(meter, price.id);
57 }
58 }
59 }
60 }
61
62 Ok(())
63 }
64
65 pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
66 let input_tokens_price = self
67 .get_or_insert_price(
68 &format!("model_{}/input_tokens", model.id),
69 &format!("{} (Input Tokens)", model.name),
70 Cents::new(model.price_per_million_input_tokens as u32),
71 )
72 .await?;
73 let input_cache_creation_tokens_price = self
74 .get_or_insert_price(
75 &format!("model_{}/input_cache_creation_tokens", model.id),
76 &format!("{} (Input Cache Creation Tokens)", model.name),
77 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
78 )
79 .await?;
80 let input_cache_read_tokens_price = self
81 .get_or_insert_price(
82 &format!("model_{}/input_cache_read_tokens", model.id),
83 &format!("{} (Input Cache Read Tokens)", model.name),
84 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
85 )
86 .await?;
87 let output_tokens_price = self
88 .get_or_insert_price(
89 &format!("model_{}/output_tokens", model.id),
90 &format!("{} (Output Tokens)", model.name),
91 Cents::new(model.price_per_million_output_tokens as u32),
92 )
93 .await?;
94 Ok(StripeModel {
95 input_tokens_price,
96 input_cache_creation_tokens_price,
97 input_cache_read_tokens_price,
98 output_tokens_price,
99 })
100 }
101
102 async fn get_or_insert_price(
103 &self,
104 meter_event_name: &str,
105 price_description: &str,
106 price_per_million_tokens: Cents,
107 ) -> Result<StripeBillingPrice> {
108 let meter =
109 if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
110 meter.clone()
111 } else {
112 let meter = StripeMeter::create(
113 &self.client,
114 StripeCreateMeterParams {
115 default_aggregation: DefaultAggregation { formula: "sum" },
116 display_name: price_description.to_string(),
117 event_name: meter_event_name,
118 },
119 )
120 .await?;
121 self.meters_by_event_name
122 .write()
123 .await
124 .insert(meter_event_name.to_string(), meter.clone());
125 meter
126 };
127
128 let price_id =
129 if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
130 price_id.clone()
131 } else {
132 let price = stripe::Price::create(
133 &self.client,
134 stripe::CreatePrice {
135 active: Some(true),
136 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
137 currency: stripe::Currency::USD,
138 currency_options: None,
139 custom_unit_amount: None,
140 expand: &[],
141 lookup_key: None,
142 metadata: None,
143 nickname: None,
144 product: None,
145 product_data: Some(stripe::CreatePriceProductData {
146 id: None,
147 active: Some(true),
148 metadata: None,
149 name: price_description.to_string(),
150 statement_descriptor: None,
151 tax_code: None,
152 unit_label: None,
153 }),
154 recurring: Some(stripe::CreatePriceRecurring {
155 aggregate_usage: None,
156 interval: stripe::CreatePriceRecurringInterval::Month,
157 interval_count: None,
158 trial_period_days: None,
159 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
160 meter: Some(meter.id.clone()),
161 }),
162 tax_behavior: None,
163 tiers: None,
164 tiers_mode: None,
165 transfer_lookup_key: None,
166 transform_quantity: None,
167 unit_amount: None,
168 unit_amount_decimal: Some(&format!(
169 "{:.12}",
170 price_per_million_tokens.0 as f64 / 1_000_000f64
171 )),
172 },
173 )
174 .await?;
175 self.price_ids_by_meter_id
176 .write()
177 .await
178 .insert(meter.id, price.id.clone());
179 price.id
180 };
181
182 Ok(StripeBillingPrice {
183 id: price_id,
184 meter_event_name: meter_event_name.to_string(),
185 })
186 }
187
188 pub async fn subscribe_to_model(
189 &self,
190 subscription_id: &stripe::SubscriptionId,
191 model: &StripeModel,
192 ) -> Result<()> {
193 let subscription =
194 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
195
196 let mut items = Vec::new();
197
198 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
199 items.push(stripe::UpdateSubscriptionItems {
200 price: Some(model.input_tokens_price.id.to_string()),
201 ..Default::default()
202 });
203 }
204
205 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
206 {
207 items.push(stripe::UpdateSubscriptionItems {
208 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
209 ..Default::default()
210 });
211 }
212
213 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
214 items.push(stripe::UpdateSubscriptionItems {
215 price: Some(model.input_cache_read_tokens_price.id.to_string()),
216 ..Default::default()
217 });
218 }
219
220 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
221 items.push(stripe::UpdateSubscriptionItems {
222 price: Some(model.output_tokens_price.id.to_string()),
223 ..Default::default()
224 });
225 }
226
227 if !items.is_empty() {
228 items.extend(subscription.items.data.iter().map(|item| {
229 stripe::UpdateSubscriptionItems {
230 id: Some(item.id.to_string()),
231 ..Default::default()
232 }
233 }));
234
235 stripe::Subscription::update(
236 &self.client,
237 subscription_id,
238 stripe::UpdateSubscription {
239 items: Some(items),
240 ..Default::default()
241 },
242 )
243 .await?;
244 }
245
246 Ok(())
247 }
248
249 pub async fn bill_model_usage(
250 &self,
251 customer_id: &stripe::CustomerId,
252 model: &StripeModel,
253 event: &llm::db::billing_event::Model,
254 ) -> Result<()> {
255 let timestamp = Utc::now().timestamp();
256
257 if event.input_tokens > 0 {
258 StripeMeterEvent::create(
259 &self.client,
260 StripeCreateMeterEventParams {
261 identifier: &format!("input_tokens/{}", event.idempotency_key),
262 event_name: &model.input_tokens_price.meter_event_name,
263 payload: StripeCreateMeterEventPayload {
264 value: event.input_tokens as u64,
265 stripe_customer_id: customer_id,
266 },
267 timestamp: Some(timestamp),
268 },
269 )
270 .await?;
271 }
272
273 if event.input_cache_creation_tokens > 0 {
274 StripeMeterEvent::create(
275 &self.client,
276 StripeCreateMeterEventParams {
277 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
278 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
279 payload: StripeCreateMeterEventPayload {
280 value: event.input_cache_creation_tokens as u64,
281 stripe_customer_id: customer_id,
282 },
283 timestamp: Some(timestamp),
284 },
285 )
286 .await?;
287 }
288
289 if event.input_cache_read_tokens > 0 {
290 StripeMeterEvent::create(
291 &self.client,
292 StripeCreateMeterEventParams {
293 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
294 event_name: &model.input_cache_read_tokens_price.meter_event_name,
295 payload: StripeCreateMeterEventPayload {
296 value: event.input_cache_read_tokens as u64,
297 stripe_customer_id: customer_id,
298 },
299 timestamp: Some(timestamp),
300 },
301 )
302 .await?;
303 }
304
305 if event.output_tokens > 0 {
306 StripeMeterEvent::create(
307 &self.client,
308 StripeCreateMeterEventParams {
309 identifier: &format!("output_tokens/{}", event.idempotency_key),
310 event_name: &model.output_tokens_price.meter_event_name,
311 payload: StripeCreateMeterEventPayload {
312 value: event.output_tokens as u64,
313 stripe_customer_id: customer_id,
314 },
315 timestamp: Some(timestamp),
316 },
317 )
318 .await?;
319 }
320
321 Ok(())
322 }
323
324 pub async fn checkout(
325 &self,
326 customer_id: stripe::CustomerId,
327 github_login: &str,
328 model: &StripeModel,
329 success_url: &str,
330 ) -> Result<String> {
331 let mut params = stripe::CreateCheckoutSession::new();
332 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
333 params.customer = Some(customer_id);
334 params.client_reference_id = Some(github_login);
335 params.line_items = Some(
336 [
337 &model.input_tokens_price.id,
338 &model.input_cache_creation_tokens_price.id,
339 &model.input_cache_read_tokens_price.id,
340 &model.output_tokens_price.id,
341 ]
342 .into_iter()
343 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
344 price: Some(price_id.to_string()),
345 ..Default::default()
346 })
347 .collect(),
348 );
349 params.success_url = Some(success_url);
350
351 let session = stripe::CheckoutSession::create(&self.client, params).await?;
352 Ok(session.url.context("no checkout session URL")?)
353 }
354}
355
/// The `default_aggregation` parameter sent when creating a meter.
#[derive(Serialize)]
struct DefaultAggregation {
    /// Name of the aggregation formula; this file only ever sends `"sum"`.
    formula: &'static str,
}
360
/// Form parameters for creating a billing meter.
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
    /// How reported values are aggregated.
    default_aggregation: DefaultAggregation,
    /// Display name of the meter.
    display_name: String,
    /// Event name that meter events must use to be counted by this meter.
    event_name: &'a str,
}
367
/// The subset of a Stripe billing meter that this module deserializes.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// Id of the meter; used as the key when looking up the price attached to it.
    id: String,
    /// Event name of the meter; used as the key of the meter cache.
    event_name: String,
}
373
374impl StripeMeter {
375 pub fn create(
376 client: &stripe::Client,
377 params: StripeCreateMeterParams,
378 ) -> stripe::Response<Self> {
379 client.post_form("/billing/meters", params)
380 }
381
382 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
383 #[derive(Serialize)]
384 struct Params {}
385
386 client.get_query("/billing/meters", Params {})
387 }
388}
389
/// The subset of a Stripe meter event that this module deserializes.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier of the meter event.
    identifier: String,
}
394
395impl StripeMeterEvent {
396 pub async fn create(
397 client: &stripe::Client,
398 params: StripeCreateMeterEventParams<'_>,
399 ) -> Result<Self, stripe::StripeError> {
400 let identifier = params.identifier;
401 match client.post_form("/billing/meter_events", params).await {
402 Ok(event) => Ok(event),
403 Err(stripe::StripeError::Stripe(error)) => {
404 if error.http_status == 400
405 && error
406 .message
407 .as_ref()
408 .map_or(false, |message| message.contains(identifier))
409 {
410 Ok(Self {
411 identifier: identifier.to_string(),
412 })
413 } else {
414 Err(stripe::StripeError::Stripe(error))
415 }
416 }
417 Err(error) => Err(error),
418 }
419 }
420}
421
/// Form parameters for creating a meter event.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier of the event; this module builds it as `{category}/{idempotency_key}`.
    identifier: &'a str,
    /// Event name of the meter the usage is reported to.
    event_name: &'a str,
    /// The reported value and the customer it belongs to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Optional timestamp of the event; this module passes seconds from `Utc::now().timestamp()`.
    timestamp: Option<i64>,
}
429
/// Payload of a meter event: the amount to record and the customer it is recorded for.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Amount to record; this module sends a token count.
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
435
436fn subscription_contains_price(
437 subscription: &stripe::Subscription,
438 price_id: &stripe::PriceId,
439) -> bool {
440 subscription.items.data.iter().any(|item| {
441 item.price
442 .as_ref()
443 .map_or(false, |price| price.id == *price_id)
444 })
445}