1use std::sync::Arc;
2
3use crate::{Cents, Result, llm};
4use anyhow::{Context as _, anyhow};
5use chrono::{Datelike, Utc};
6use collections::HashMap;
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9
10pub struct StripeBilling {
11 state: RwLock<StripeBillingState>,
12 client: Arc<stripe::Client>,
13 zed_pro_price_id: Option<String>,
14}
15
16#[derive(Default)]
17struct StripeBillingState {
18 meters_by_event_name: HashMap<String, StripeMeter>,
19 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
20}
21
22pub struct StripeModel {
23 input_tokens_price: StripeBillingPrice,
24 input_cache_creation_tokens_price: StripeBillingPrice,
25 input_cache_read_tokens_price: StripeBillingPrice,
26 output_tokens_price: StripeBillingPrice,
27}
28
29struct StripeBillingPrice {
30 id: stripe::PriceId,
31 meter_event_name: String,
32}
33
34impl StripeBilling {
35 pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
36 Self {
37 client,
38 state: RwLock::default(),
39 zed_pro_price_id,
40 }
41 }
42
43 pub async fn initialize(&self) -> Result<()> {
44 log::info!("StripeBilling: initializing");
45
46 let mut state = self.state.write().await;
47
48 let (meters, prices) = futures::try_join!(
49 StripeMeter::list(&self.client),
50 stripe::Price::list(
51 &self.client,
52 &stripe::ListPrices {
53 limit: Some(100),
54 ..Default::default()
55 }
56 )
57 )?;
58
59 for meter in meters.data {
60 state
61 .meters_by_event_name
62 .insert(meter.event_name.clone(), meter);
63 }
64
65 for price in prices.data {
66 if let Some(recurring) = price.recurring {
67 if let Some(meter) = recurring.meter {
68 state.price_ids_by_meter_id.insert(meter, price.id);
69 }
70 }
71 }
72
73 log::info!("StripeBilling: initialized");
74
75 Ok(())
76 }
77
78 pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
79 let input_tokens_price = self
80 .get_or_insert_price(
81 &format!("model_{}/input_tokens", model.id),
82 &format!("{} (Input Tokens)", model.name),
83 Cents::new(model.price_per_million_input_tokens as u32),
84 )
85 .await?;
86 let input_cache_creation_tokens_price = self
87 .get_or_insert_price(
88 &format!("model_{}/input_cache_creation_tokens", model.id),
89 &format!("{} (Input Cache Creation Tokens)", model.name),
90 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
91 )
92 .await?;
93 let input_cache_read_tokens_price = self
94 .get_or_insert_price(
95 &format!("model_{}/input_cache_read_tokens", model.id),
96 &format!("{} (Input Cache Read Tokens)", model.name),
97 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
98 )
99 .await?;
100 let output_tokens_price = self
101 .get_or_insert_price(
102 &format!("model_{}/output_tokens", model.id),
103 &format!("{} (Output Tokens)", model.name),
104 Cents::new(model.price_per_million_output_tokens as u32),
105 )
106 .await?;
107 Ok(StripeModel {
108 input_tokens_price,
109 input_cache_creation_tokens_price,
110 input_cache_read_tokens_price,
111 output_tokens_price,
112 })
113 }
114
115 async fn get_or_insert_price(
116 &self,
117 meter_event_name: &str,
118 price_description: &str,
119 price_per_million_tokens: Cents,
120 ) -> Result<StripeBillingPrice> {
121 // Fast code path when the meter and the price already exist.
122 {
123 let state = self.state.read().await;
124 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
125 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
126 return Ok(StripeBillingPrice {
127 id: price_id.clone(),
128 meter_event_name: meter_event_name.to_string(),
129 });
130 }
131 }
132 }
133
134 let mut state = self.state.write().await;
135 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
136 meter.clone()
137 } else {
138 let meter = StripeMeter::create(
139 &self.client,
140 StripeCreateMeterParams {
141 default_aggregation: DefaultAggregation { formula: "sum" },
142 display_name: price_description.to_string(),
143 event_name: meter_event_name,
144 },
145 )
146 .await?;
147 state
148 .meters_by_event_name
149 .insert(meter_event_name.to_string(), meter.clone());
150 meter
151 };
152
153 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
154 price_id.clone()
155 } else {
156 let price = stripe::Price::create(
157 &self.client,
158 stripe::CreatePrice {
159 active: Some(true),
160 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
161 currency: stripe::Currency::USD,
162 currency_options: None,
163 custom_unit_amount: None,
164 expand: &[],
165 lookup_key: None,
166 metadata: None,
167 nickname: None,
168 product: None,
169 product_data: Some(stripe::CreatePriceProductData {
170 id: None,
171 active: Some(true),
172 metadata: None,
173 name: price_description.to_string(),
174 statement_descriptor: None,
175 tax_code: None,
176 unit_label: None,
177 }),
178 recurring: Some(stripe::CreatePriceRecurring {
179 aggregate_usage: None,
180 interval: stripe::CreatePriceRecurringInterval::Month,
181 interval_count: None,
182 trial_period_days: None,
183 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
184 meter: Some(meter.id.clone()),
185 }),
186 tax_behavior: None,
187 tiers: None,
188 tiers_mode: None,
189 transfer_lookup_key: None,
190 transform_quantity: None,
191 unit_amount: None,
192 unit_amount_decimal: Some(&format!(
193 "{:.12}",
194 price_per_million_tokens.0 as f64 / 1_000_000f64
195 )),
196 },
197 )
198 .await?;
199 state
200 .price_ids_by_meter_id
201 .insert(meter.id, price.id.clone());
202 price.id
203 };
204
205 Ok(StripeBillingPrice {
206 id: price_id,
207 meter_event_name: meter_event_name.to_string(),
208 })
209 }
210
211 pub async fn subscribe_to_model(
212 &self,
213 subscription_id: &stripe::SubscriptionId,
214 model: &StripeModel,
215 ) -> Result<()> {
216 let subscription =
217 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
218
219 let mut items = Vec::new();
220
221 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
222 items.push(stripe::UpdateSubscriptionItems {
223 price: Some(model.input_tokens_price.id.to_string()),
224 ..Default::default()
225 });
226 }
227
228 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
229 {
230 items.push(stripe::UpdateSubscriptionItems {
231 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
232 ..Default::default()
233 });
234 }
235
236 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
237 items.push(stripe::UpdateSubscriptionItems {
238 price: Some(model.input_cache_read_tokens_price.id.to_string()),
239 ..Default::default()
240 });
241 }
242
243 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
244 items.push(stripe::UpdateSubscriptionItems {
245 price: Some(model.output_tokens_price.id.to_string()),
246 ..Default::default()
247 });
248 }
249
250 if !items.is_empty() {
251 items.extend(subscription.items.data.iter().map(|item| {
252 stripe::UpdateSubscriptionItems {
253 id: Some(item.id.to_string()),
254 ..Default::default()
255 }
256 }));
257
258 stripe::Subscription::update(
259 &self.client,
260 subscription_id,
261 stripe::UpdateSubscription {
262 items: Some(items),
263 ..Default::default()
264 },
265 )
266 .await?;
267 }
268
269 Ok(())
270 }
271
272 pub async fn bill_model_usage(
273 &self,
274 customer_id: &stripe::CustomerId,
275 model: &StripeModel,
276 event: &llm::db::billing_event::Model,
277 ) -> Result<()> {
278 let timestamp = Utc::now().timestamp();
279
280 if event.input_tokens > 0 {
281 StripeMeterEvent::create(
282 &self.client,
283 StripeCreateMeterEventParams {
284 identifier: &format!("input_tokens/{}", event.idempotency_key),
285 event_name: &model.input_tokens_price.meter_event_name,
286 payload: StripeCreateMeterEventPayload {
287 value: event.input_tokens as u64,
288 stripe_customer_id: customer_id,
289 },
290 timestamp: Some(timestamp),
291 },
292 )
293 .await?;
294 }
295
296 if event.input_cache_creation_tokens > 0 {
297 StripeMeterEvent::create(
298 &self.client,
299 StripeCreateMeterEventParams {
300 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
301 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
302 payload: StripeCreateMeterEventPayload {
303 value: event.input_cache_creation_tokens as u64,
304 stripe_customer_id: customer_id,
305 },
306 timestamp: Some(timestamp),
307 },
308 )
309 .await?;
310 }
311
312 if event.input_cache_read_tokens > 0 {
313 StripeMeterEvent::create(
314 &self.client,
315 StripeCreateMeterEventParams {
316 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
317 event_name: &model.input_cache_read_tokens_price.meter_event_name,
318 payload: StripeCreateMeterEventPayload {
319 value: event.input_cache_read_tokens as u64,
320 stripe_customer_id: customer_id,
321 },
322 timestamp: Some(timestamp),
323 },
324 )
325 .await?;
326 }
327
328 if event.output_tokens > 0 {
329 StripeMeterEvent::create(
330 &self.client,
331 StripeCreateMeterEventParams {
332 identifier: &format!("output_tokens/{}", event.idempotency_key),
333 event_name: &model.output_tokens_price.meter_event_name,
334 payload: StripeCreateMeterEventPayload {
335 value: event.output_tokens as u64,
336 stripe_customer_id: customer_id,
337 },
338 timestamp: Some(timestamp),
339 },
340 )
341 .await?;
342 }
343
344 Ok(())
345 }
346
347 pub async fn checkout(
348 &self,
349 customer_id: stripe::CustomerId,
350 github_login: &str,
351 model: &StripeModel,
352 success_url: &str,
353 ) -> Result<String> {
354 let first_of_next_month = Utc::now()
355 .checked_add_months(chrono::Months::new(1))
356 .unwrap()
357 .with_day(1)
358 .unwrap();
359
360 let mut params = stripe::CreateCheckoutSession::new();
361 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
362 params.customer = Some(customer_id);
363 params.client_reference_id = Some(github_login);
364 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
365 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
366 ..Default::default()
367 });
368 params.line_items = Some(
369 [
370 &model.input_tokens_price.id,
371 &model.input_cache_creation_tokens_price.id,
372 &model.input_cache_read_tokens_price.id,
373 &model.output_tokens_price.id,
374 ]
375 .into_iter()
376 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
377 price: Some(price_id.to_string()),
378 ..Default::default()
379 })
380 .collect(),
381 );
382 params.success_url = Some(success_url);
383
384 let session = stripe::CheckoutSession::create(&self.client, params).await?;
385 Ok(session.url.context("no checkout session URL")?)
386 }
387
388 pub async fn checkout_with_zed_pro(
389 &self,
390 customer_id: stripe::CustomerId,
391 github_login: &str,
392 success_url: &str,
393 ) -> Result<String> {
394 let zed_pro_price_id = self
395 .zed_pro_price_id
396 .as_ref()
397 .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
398
399 let mut params = stripe::CreateCheckoutSession::new();
400 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
401 params.customer = Some(customer_id);
402 params.client_reference_id = Some(github_login);
403 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
404 price: Some(zed_pro_price_id.clone()),
405 quantity: Some(1),
406 ..Default::default()
407 }]);
408 params.success_url = Some(success_url);
409
410 let session = stripe::CheckoutSession::create(&self.client, params).await?;
411 Ok(session.url.context("no checkout session URL")?)
412 }
413}
414
415#[derive(Serialize)]
416struct DefaultAggregation {
417 formula: &'static str,
418}
419
420#[derive(Serialize)]
421struct StripeCreateMeterParams<'a> {
422 default_aggregation: DefaultAggregation,
423 display_name: String,
424 event_name: &'a str,
425}
426
427#[derive(Clone, Deserialize)]
428struct StripeMeter {
429 id: String,
430 event_name: String,
431}
432
433impl StripeMeter {
434 pub fn create(
435 client: &stripe::Client,
436 params: StripeCreateMeterParams,
437 ) -> stripe::Response<Self> {
438 client.post_form("/billing/meters", params)
439 }
440
441 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
442 #[derive(Serialize)]
443 struct Params {
444 #[serde(skip_serializing_if = "Option::is_none")]
445 limit: Option<u64>,
446 }
447
448 client.get_query("/billing/meters", Params { limit: Some(100) })
449 }
450}
451
452#[derive(Deserialize)]
453struct StripeMeterEvent {
454 identifier: String,
455}
456
457impl StripeMeterEvent {
458 pub async fn create(
459 client: &stripe::Client,
460 params: StripeCreateMeterEventParams<'_>,
461 ) -> Result<Self, stripe::StripeError> {
462 let identifier = params.identifier;
463 match client.post_form("/billing/meter_events", params).await {
464 Ok(event) => Ok(event),
465 Err(stripe::StripeError::Stripe(error)) => {
466 if error.http_status == 400
467 && error
468 .message
469 .as_ref()
470 .map_or(false, |message| message.contains(identifier))
471 {
472 Ok(Self {
473 identifier: identifier.to_string(),
474 })
475 } else {
476 Err(stripe::StripeError::Stripe(error))
477 }
478 }
479 Err(error) => Err(error),
480 }
481 }
482}
483
484#[derive(Serialize)]
485struct StripeCreateMeterEventParams<'a> {
486 identifier: &'a str,
487 event_name: &'a str,
488 payload: StripeCreateMeterEventPayload<'a>,
489 timestamp: Option<i64>,
490}
491
492#[derive(Serialize)]
493struct StripeCreateMeterEventPayload<'a> {
494 value: u64,
495 stripe_customer_id: &'a stripe::CustomerId,
496}
497
498fn subscription_contains_price(
499 subscription: &stripe::Subscription,
500 price_id: &stripe::PriceId,
501) -> bool {
502 subscription.items.data.iter().any(|item| {
503 item.price
504 .as_ref()
505 .map_or(false, |price| price.id == *price_id)
506 })
507}