1use std::sync::Arc;
2
3use crate::{Cents, Result, llm};
4use anyhow::Context as _;
5use chrono::{Datelike, Utc};
6use collections::HashMap;
7use serde::{Deserialize, Serialize};
8use stripe::PriceId;
9use tokio::sync::RwLock;
10
/// Usage-based billing through Stripe.
///
/// Holds a shared Stripe API client plus an in-memory cache of billing meters
/// and the prices attached to them, so repeated lookups can skip Stripe
/// round-trips.
pub struct StripeBilling {
    /// Cached meters and prices; see `StripeBillingState`.
    state: RwLock<StripeBillingState>,
    /// Shared Stripe API client used for every request.
    client: Arc<stripe::Client>,
}
15
16#[derive(Default)]
17struct StripeBillingState {
18 meters_by_event_name: HashMap<String, StripeMeter>,
19 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
20}
21
/// The Stripe prices registered for one LLM model, one per billed token kind.
///
/// Produced by `StripeBilling::register_model`.
pub struct StripeModel {
    /// Price for regular input tokens.
    input_tokens_price: StripeBillingPrice,
    /// Price for input tokens written to the prompt cache.
    input_cache_creation_tokens_price: StripeBillingPrice,
    /// Price for input tokens read from the prompt cache.
    input_cache_read_tokens_price: StripeBillingPrice,
    /// Price for output tokens.
    output_tokens_price: StripeBillingPrice,
}
28
/// A metered Stripe price together with the event name of the meter it bills.
struct StripeBillingPrice {
    /// Stripe price ID, used for subscription items and checkout line items.
    id: stripe::PriceId,
    /// Event name that meter events must carry to be billed against this price.
    meter_event_name: String,
}
33
34impl StripeBilling {
35 pub fn new(client: Arc<stripe::Client>) -> Self {
36 Self {
37 client,
38 state: RwLock::default(),
39 }
40 }
41
42 pub async fn initialize(&self) -> Result<()> {
43 log::info!("StripeBilling: initializing");
44
45 let mut state = self.state.write().await;
46
47 let (meters, prices) = futures::try_join!(
48 StripeMeter::list(&self.client),
49 stripe::Price::list(
50 &self.client,
51 &stripe::ListPrices {
52 limit: Some(100),
53 ..Default::default()
54 }
55 )
56 )?;
57
58 for meter in meters.data {
59 state
60 .meters_by_event_name
61 .insert(meter.event_name.clone(), meter);
62 }
63
64 for price in prices.data {
65 if let Some(recurring) = price.recurring {
66 if let Some(meter) = recurring.meter {
67 state.price_ids_by_meter_id.insert(meter, price.id);
68 }
69 }
70 }
71
72 log::info!("StripeBilling: initialized");
73
74 Ok(())
75 }
76
77 pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
78 let input_tokens_price = self
79 .get_or_insert_price(
80 &format!("model_{}/input_tokens", model.id),
81 &format!("{} (Input Tokens)", model.name),
82 Cents::new(model.price_per_million_input_tokens as u32),
83 )
84 .await?;
85 let input_cache_creation_tokens_price = self
86 .get_or_insert_price(
87 &format!("model_{}/input_cache_creation_tokens", model.id),
88 &format!("{} (Input Cache Creation Tokens)", model.name),
89 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
90 )
91 .await?;
92 let input_cache_read_tokens_price = self
93 .get_or_insert_price(
94 &format!("model_{}/input_cache_read_tokens", model.id),
95 &format!("{} (Input Cache Read Tokens)", model.name),
96 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
97 )
98 .await?;
99 let output_tokens_price = self
100 .get_or_insert_price(
101 &format!("model_{}/output_tokens", model.id),
102 &format!("{} (Output Tokens)", model.name),
103 Cents::new(model.price_per_million_output_tokens as u32),
104 )
105 .await?;
106 Ok(StripeModel {
107 input_tokens_price,
108 input_cache_creation_tokens_price,
109 input_cache_read_tokens_price,
110 output_tokens_price,
111 })
112 }
113
114 async fn get_or_insert_price(
115 &self,
116 meter_event_name: &str,
117 price_description: &str,
118 price_per_million_tokens: Cents,
119 ) -> Result<StripeBillingPrice> {
120 // Fast code path when the meter and the price already exist.
121 {
122 let state = self.state.read().await;
123 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
124 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
125 return Ok(StripeBillingPrice {
126 id: price_id.clone(),
127 meter_event_name: meter_event_name.to_string(),
128 });
129 }
130 }
131 }
132
133 let mut state = self.state.write().await;
134 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
135 meter.clone()
136 } else {
137 let meter = StripeMeter::create(
138 &self.client,
139 StripeCreateMeterParams {
140 default_aggregation: DefaultAggregation { formula: "sum" },
141 display_name: price_description.to_string(),
142 event_name: meter_event_name,
143 },
144 )
145 .await?;
146 state
147 .meters_by_event_name
148 .insert(meter_event_name.to_string(), meter.clone());
149 meter
150 };
151
152 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
153 price_id.clone()
154 } else {
155 let price = stripe::Price::create(
156 &self.client,
157 stripe::CreatePrice {
158 active: Some(true),
159 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
160 currency: stripe::Currency::USD,
161 currency_options: None,
162 custom_unit_amount: None,
163 expand: &[],
164 lookup_key: None,
165 metadata: None,
166 nickname: None,
167 product: None,
168 product_data: Some(stripe::CreatePriceProductData {
169 id: None,
170 active: Some(true),
171 metadata: None,
172 name: price_description.to_string(),
173 statement_descriptor: None,
174 tax_code: None,
175 unit_label: None,
176 }),
177 recurring: Some(stripe::CreatePriceRecurring {
178 aggregate_usage: None,
179 interval: stripe::CreatePriceRecurringInterval::Month,
180 interval_count: None,
181 trial_period_days: None,
182 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
183 meter: Some(meter.id.clone()),
184 }),
185 tax_behavior: None,
186 tiers: None,
187 tiers_mode: None,
188 transfer_lookup_key: None,
189 transform_quantity: None,
190 unit_amount: None,
191 unit_amount_decimal: Some(&format!(
192 "{:.12}",
193 price_per_million_tokens.0 as f64 / 1_000_000f64
194 )),
195 },
196 )
197 .await?;
198 state
199 .price_ids_by_meter_id
200 .insert(meter.id, price.id.clone());
201 price.id
202 };
203
204 Ok(StripeBillingPrice {
205 id: price_id,
206 meter_event_name: meter_event_name.to_string(),
207 })
208 }
209
210 pub async fn subscribe_to_model(
211 &self,
212 subscription_id: &stripe::SubscriptionId,
213 model: &StripeModel,
214 ) -> Result<()> {
215 let subscription =
216 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
217
218 let mut items = Vec::new();
219
220 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
221 items.push(stripe::UpdateSubscriptionItems {
222 price: Some(model.input_tokens_price.id.to_string()),
223 ..Default::default()
224 });
225 }
226
227 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
228 {
229 items.push(stripe::UpdateSubscriptionItems {
230 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
231 ..Default::default()
232 });
233 }
234
235 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
236 items.push(stripe::UpdateSubscriptionItems {
237 price: Some(model.input_cache_read_tokens_price.id.to_string()),
238 ..Default::default()
239 });
240 }
241
242 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
243 items.push(stripe::UpdateSubscriptionItems {
244 price: Some(model.output_tokens_price.id.to_string()),
245 ..Default::default()
246 });
247 }
248
249 if !items.is_empty() {
250 items.extend(subscription.items.data.iter().map(|item| {
251 stripe::UpdateSubscriptionItems {
252 id: Some(item.id.to_string()),
253 ..Default::default()
254 }
255 }));
256
257 stripe::Subscription::update(
258 &self.client,
259 subscription_id,
260 stripe::UpdateSubscription {
261 items: Some(items),
262 ..Default::default()
263 },
264 )
265 .await?;
266 }
267
268 Ok(())
269 }
270
271 pub async fn bill_model_token_usage(
272 &self,
273 customer_id: &stripe::CustomerId,
274 model: &StripeModel,
275 event: &llm::db::billing_event::Model,
276 ) -> Result<()> {
277 let timestamp = Utc::now().timestamp();
278
279 if event.input_tokens > 0 {
280 StripeMeterEvent::create(
281 &self.client,
282 StripeCreateMeterEventParams {
283 identifier: &format!("input_tokens/{}", event.idempotency_key),
284 event_name: &model.input_tokens_price.meter_event_name,
285 payload: StripeCreateMeterEventPayload {
286 value: event.input_tokens as u64,
287 stripe_customer_id: customer_id,
288 },
289 timestamp: Some(timestamp),
290 },
291 )
292 .await?;
293 }
294
295 if event.input_cache_creation_tokens > 0 {
296 StripeMeterEvent::create(
297 &self.client,
298 StripeCreateMeterEventParams {
299 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
300 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
301 payload: StripeCreateMeterEventPayload {
302 value: event.input_cache_creation_tokens as u64,
303 stripe_customer_id: customer_id,
304 },
305 timestamp: Some(timestamp),
306 },
307 )
308 .await?;
309 }
310
311 if event.input_cache_read_tokens > 0 {
312 StripeMeterEvent::create(
313 &self.client,
314 StripeCreateMeterEventParams {
315 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
316 event_name: &model.input_cache_read_tokens_price.meter_event_name,
317 payload: StripeCreateMeterEventPayload {
318 value: event.input_cache_read_tokens as u64,
319 stripe_customer_id: customer_id,
320 },
321 timestamp: Some(timestamp),
322 },
323 )
324 .await?;
325 }
326
327 if event.output_tokens > 0 {
328 StripeMeterEvent::create(
329 &self.client,
330 StripeCreateMeterEventParams {
331 identifier: &format!("output_tokens/{}", event.idempotency_key),
332 event_name: &model.output_tokens_price.meter_event_name,
333 payload: StripeCreateMeterEventPayload {
334 value: event.output_tokens as u64,
335 stripe_customer_id: customer_id,
336 },
337 timestamp: Some(timestamp),
338 },
339 )
340 .await?;
341 }
342
343 Ok(())
344 }
345
346 pub async fn checkout(
347 &self,
348 customer_id: stripe::CustomerId,
349 github_login: &str,
350 model: &StripeModel,
351 success_url: &str,
352 ) -> Result<String> {
353 let first_of_next_month = Utc::now()
354 .checked_add_months(chrono::Months::new(1))
355 .unwrap()
356 .with_day(1)
357 .unwrap();
358
359 let mut params = stripe::CreateCheckoutSession::new();
360 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
361 params.customer = Some(customer_id);
362 params.client_reference_id = Some(github_login);
363 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
364 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
365 ..Default::default()
366 });
367 params.line_items = Some(
368 [
369 &model.input_tokens_price.id,
370 &model.input_cache_creation_tokens_price.id,
371 &model.input_cache_read_tokens_price.id,
372 &model.output_tokens_price.id,
373 ]
374 .into_iter()
375 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
376 price: Some(price_id.to_string()),
377 ..Default::default()
378 })
379 .collect(),
380 );
381 params.success_url = Some(success_url);
382
383 let session = stripe::CheckoutSession::create(&self.client, params).await?;
384 Ok(session.url.context("no checkout session URL")?)
385 }
386
387 pub async fn checkout_with_price(
388 &self,
389 price_id: PriceId,
390 customer_id: stripe::CustomerId,
391 github_login: &str,
392 success_url: &str,
393 ) -> Result<String> {
394 let mut params = stripe::CreateCheckoutSession::new();
395 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
396 params.customer = Some(customer_id);
397 params.client_reference_id = Some(github_login);
398 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
399 price: Some(price_id.to_string()),
400 quantity: Some(1),
401 ..Default::default()
402 }]);
403 params.success_url = Some(success_url);
404
405 let session = stripe::CheckoutSession::create(&self.client, params).await?;
406 Ok(session.url.context("no checkout session URL")?)
407 }
408
409 pub async fn checkout_with_zed_pro_trial(
410 &self,
411 zed_pro_price_id: PriceId,
412 customer_id: stripe::CustomerId,
413 github_login: &str,
414 success_url: &str,
415 ) -> Result<String> {
416 let mut params = stripe::CreateCheckoutSession::new();
417 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
418 trial_period_days: Some(14),
419 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
420 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
421 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
422 }
423 }),
424 ..Default::default()
425 });
426 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
427 params.payment_method_collection =
428 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
429 params.customer = Some(customer_id);
430 params.client_reference_id = Some(github_login);
431 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
432 price: Some(zed_pro_price_id.to_string()),
433 quantity: Some(1),
434 ..Default::default()
435 }]);
436 params.success_url = Some(success_url);
437
438 let session = stripe::CheckoutSession::create(&self.client, params).await?;
439 Ok(session.url.context("no checkout session URL")?)
440 }
441}
442
/// `default_aggregation` form parameter for creating a Stripe billing meter.
#[derive(Serialize)]
struct DefaultAggregation {
    /// Aggregation formula for meter event values; this file only uses `"sum"`.
    formula: &'static str,
}
447
/// Form body for `POST /billing/meters`, sent by `StripeMeter::create`.
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
    /// How the meter aggregates event values.
    default_aggregation: DefaultAggregation,
    /// Display name for the meter (`get_or_insert_price` passes the price description).
    display_name: String,
    /// Name that meter events reference in their `event_name` field.
    event_name: &'a str,
}
454
/// Subset of a Stripe billing meter object; only `id` and `event_name` are
/// deserialized.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// Stripe meter ID; prices refer to the meter through this value.
    id: String,
    /// Event name that meter events use to target this meter.
    event_name: String,
}
460
461impl StripeMeter {
462 pub fn create(
463 client: &stripe::Client,
464 params: StripeCreateMeterParams,
465 ) -> stripe::Response<Self> {
466 client.post_form("/billing/meters", params)
467 }
468
469 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
470 #[derive(Serialize)]
471 struct Params {
472 #[serde(skip_serializing_if = "Option::is_none")]
473 limit: Option<u64>,
474 }
475
476 client.get_query("/billing/meters", Params { limit: Some(100) })
477 }
478}
479
/// Result of creating a Stripe meter event; only the `identifier` is kept.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier of the event. `StripeMeterEvent::create` also fills this in
    /// itself when Stripe reports the identifier as a duplicate.
    identifier: String,
}
484
485impl StripeMeterEvent {
486 pub async fn create(
487 client: &stripe::Client,
488 params: StripeCreateMeterEventParams<'_>,
489 ) -> Result<Self, stripe::StripeError> {
490 let identifier = params.identifier;
491 match client.post_form("/billing/meter_events", params).await {
492 Ok(event) => Ok(event),
493 Err(stripe::StripeError::Stripe(error)) => {
494 if error.http_status == 400
495 && error
496 .message
497 .as_ref()
498 .map_or(false, |message| message.contains(identifier))
499 {
500 Ok(Self {
501 identifier: identifier.to_string(),
502 })
503 } else {
504 Err(stripe::StripeError::Stripe(error))
505 }
506 }
507 Err(error) => Err(error),
508 }
509 }
510}
511
/// Form body for `POST /billing/meter_events`.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier for this event; `StripeMeterEvent::create` uses it to
    /// recognize duplicate submissions.
    identifier: &'a str,
    /// Event name of the meter this event is counted against.
    event_name: &'a str,
    /// Usage value and the customer it is billed to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time. Callers in this file pass `Utc::now().timestamp()`, which is
    /// seconds since the Unix epoch.
    timestamp: Option<i64>,
}
519
/// `payload` of a meter event: the usage amount and the customer it belongs to.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Usage amount; this file always sends a token count.
    value: u64,
    /// Stripe customer to bill.
    stripe_customer_id: &'a stripe::CustomerId,
}
525
526fn subscription_contains_price(
527 subscription: &stripe::Subscription,
528 price_id: &stripe::PriceId,
529) -> bool {
530 subscription.items.data.iter().any(|item| {
531 item.price
532 .as_ref()
533 .map_or(false, |price| price.id == *price_id)
534 })
535}