1use std::sync::Arc;
2
3use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
4use crate::{Cents, Result};
5use anyhow::{Context as _, anyhow};
6use chrono::{Datelike, Utc};
7use collections::HashMap;
8use serde::{Deserialize, Serialize};
9use stripe::PriceId;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13pub struct StripeBilling {
14 state: RwLock<StripeBillingState>,
15 client: Arc<stripe::Client>,
16}
17
18#[derive(Default)]
19struct StripeBillingState {
20 meters_by_event_name: HashMap<String, StripeMeter>,
21 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
22 prices_by_lookup_key: HashMap<String, stripe::Price>,
23}
24
25pub struct StripeModelTokenPrices {
26 input_tokens_price: StripeBillingPrice,
27 input_cache_creation_tokens_price: StripeBillingPrice,
28 input_cache_read_tokens_price: StripeBillingPrice,
29 output_tokens_price: StripeBillingPrice,
30}
31
32struct StripeBillingPrice {
33 id: stripe::PriceId,
34 meter_event_name: String,
35}
36
37impl StripeBilling {
38 pub fn new(client: Arc<stripe::Client>) -> Self {
39 Self {
40 client,
41 state: RwLock::default(),
42 }
43 }
44
45 pub async fn initialize(&self) -> Result<()> {
46 log::info!("StripeBilling: initializing");
47
48 let mut state = self.state.write().await;
49
50 let (meters, prices) = futures::try_join!(
51 StripeMeter::list(&self.client),
52 stripe::Price::list(
53 &self.client,
54 &stripe::ListPrices {
55 limit: Some(100),
56 ..Default::default()
57 }
58 )
59 )?;
60
61 for meter in meters.data {
62 state
63 .meters_by_event_name
64 .insert(meter.event_name.clone(), meter);
65 }
66
67 for price in prices.data {
68 if let Some(lookup_key) = price.lookup_key.clone() {
69 state.prices_by_lookup_key.insert(lookup_key, price.clone());
70 }
71
72 if let Some(recurring) = price.recurring {
73 if let Some(meter) = recurring.meter {
74 state.price_ids_by_meter_id.insert(meter, price.id);
75 }
76 }
77 }
78
79 log::info!("StripeBilling: initialized");
80
81 Ok(())
82 }
83
84 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
85 self.find_price_id_by_lookup_key("zed-pro").await
86 }
87
88 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
89 self.find_price_id_by_lookup_key("zed-free").await
90 }
91
92 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
93 self.state
94 .read()
95 .await
96 .prices_by_lookup_key
97 .get(lookup_key)
98 .map(|price| price.id.clone())
99 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
100 }
101
102 pub async fn register_model_for_token_based_usage(
103 &self,
104 model: &llm::db::model::Model,
105 ) -> Result<StripeModelTokenPrices> {
106 let input_tokens_price = self
107 .get_or_insert_token_price(
108 &format!("model_{}/input_tokens", model.id),
109 &format!("{} (Input Tokens)", model.name),
110 Cents::new(model.price_per_million_input_tokens as u32),
111 )
112 .await?;
113 let input_cache_creation_tokens_price = self
114 .get_or_insert_token_price(
115 &format!("model_{}/input_cache_creation_tokens", model.id),
116 &format!("{} (Input Cache Creation Tokens)", model.name),
117 Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
118 )
119 .await?;
120 let input_cache_read_tokens_price = self
121 .get_or_insert_token_price(
122 &format!("model_{}/input_cache_read_tokens", model.id),
123 &format!("{} (Input Cache Read Tokens)", model.name),
124 Cents::new(model.price_per_million_cache_read_input_tokens as u32),
125 )
126 .await?;
127 let output_tokens_price = self
128 .get_or_insert_token_price(
129 &format!("model_{}/output_tokens", model.id),
130 &format!("{} (Output Tokens)", model.name),
131 Cents::new(model.price_per_million_output_tokens as u32),
132 )
133 .await?;
134 Ok(StripeModelTokenPrices {
135 input_tokens_price,
136 input_cache_creation_tokens_price,
137 input_cache_read_tokens_price,
138 output_tokens_price,
139 })
140 }
141
142 async fn get_or_insert_token_price(
143 &self,
144 meter_event_name: &str,
145 price_description: &str,
146 price_per_million_tokens: Cents,
147 ) -> Result<StripeBillingPrice> {
148 // Fast code path when the meter and the price already exist.
149 {
150 let state = self.state.read().await;
151 if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
152 if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
153 return Ok(StripeBillingPrice {
154 id: price_id.clone(),
155 meter_event_name: meter_event_name.to_string(),
156 });
157 }
158 }
159 }
160
161 let mut state = self.state.write().await;
162 let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
163 meter.clone()
164 } else {
165 let meter = StripeMeter::create(
166 &self.client,
167 StripeCreateMeterParams {
168 default_aggregation: DefaultAggregation { formula: "sum" },
169 display_name: price_description.to_string(),
170 event_name: meter_event_name,
171 },
172 )
173 .await?;
174 state
175 .meters_by_event_name
176 .insert(meter_event_name.to_string(), meter.clone());
177 meter
178 };
179
180 let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
181 price_id.clone()
182 } else {
183 let price = stripe::Price::create(
184 &self.client,
185 stripe::CreatePrice {
186 active: Some(true),
187 billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
188 currency: stripe::Currency::USD,
189 currency_options: None,
190 custom_unit_amount: None,
191 expand: &[],
192 lookup_key: None,
193 metadata: None,
194 nickname: None,
195 product: None,
196 product_data: Some(stripe::CreatePriceProductData {
197 id: None,
198 active: Some(true),
199 metadata: None,
200 name: price_description.to_string(),
201 statement_descriptor: None,
202 tax_code: None,
203 unit_label: None,
204 }),
205 recurring: Some(stripe::CreatePriceRecurring {
206 aggregate_usage: None,
207 interval: stripe::CreatePriceRecurringInterval::Month,
208 interval_count: None,
209 trial_period_days: None,
210 usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
211 meter: Some(meter.id.clone()),
212 }),
213 tax_behavior: None,
214 tiers: None,
215 tiers_mode: None,
216 transfer_lookup_key: None,
217 transform_quantity: None,
218 unit_amount: None,
219 unit_amount_decimal: Some(&format!(
220 "{:.12}",
221 price_per_million_tokens.0 as f64 / 1_000_000f64
222 )),
223 },
224 )
225 .await?;
226 state
227 .price_ids_by_meter_id
228 .insert(meter.id, price.id.clone());
229 price.id
230 };
231
232 Ok(StripeBillingPrice {
233 id: price_id,
234 meter_event_name: meter_event_name.to_string(),
235 })
236 }
237
238 pub async fn subscribe_to_price(
239 &self,
240 subscription_id: &stripe::SubscriptionId,
241 price_id: &stripe::PriceId,
242 ) -> Result<()> {
243 let subscription =
244 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
245
246 if subscription_contains_price(&subscription, price_id) {
247 return Ok(());
248 }
249
250 stripe::Subscription::update(
251 &self.client,
252 subscription_id,
253 stripe::UpdateSubscription {
254 items: Some(vec![stripe::UpdateSubscriptionItems {
255 price: Some(price_id.to_string()),
256 ..Default::default()
257 }]),
258 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
259 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
260 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
261 },
262 }),
263 ..Default::default()
264 },
265 )
266 .await?;
267
268 Ok(())
269 }
270
271 pub async fn subscribe_to_model(
272 &self,
273 subscription_id: &stripe::SubscriptionId,
274 model: &StripeModelTokenPrices,
275 ) -> Result<()> {
276 let subscription =
277 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
278
279 let mut items = Vec::new();
280
281 if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
282 items.push(stripe::UpdateSubscriptionItems {
283 price: Some(model.input_tokens_price.id.to_string()),
284 ..Default::default()
285 });
286 }
287
288 if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
289 {
290 items.push(stripe::UpdateSubscriptionItems {
291 price: Some(model.input_cache_creation_tokens_price.id.to_string()),
292 ..Default::default()
293 });
294 }
295
296 if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
297 items.push(stripe::UpdateSubscriptionItems {
298 price: Some(model.input_cache_read_tokens_price.id.to_string()),
299 ..Default::default()
300 });
301 }
302
303 if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
304 items.push(stripe::UpdateSubscriptionItems {
305 price: Some(model.output_tokens_price.id.to_string()),
306 ..Default::default()
307 });
308 }
309
310 if !items.is_empty() {
311 items.extend(subscription.items.data.iter().map(|item| {
312 stripe::UpdateSubscriptionItems {
313 id: Some(item.id.to_string()),
314 ..Default::default()
315 }
316 }));
317
318 stripe::Subscription::update(
319 &self.client,
320 subscription_id,
321 stripe::UpdateSubscription {
322 items: Some(items),
323 ..Default::default()
324 },
325 )
326 .await?;
327 }
328
329 Ok(())
330 }
331
332 pub async fn bill_model_token_usage(
333 &self,
334 customer_id: &stripe::CustomerId,
335 model: &StripeModelTokenPrices,
336 event: &llm::db::billing_event::Model,
337 ) -> Result<()> {
338 let timestamp = Utc::now().timestamp();
339
340 if event.input_tokens > 0 {
341 StripeMeterEvent::create(
342 &self.client,
343 StripeCreateMeterEventParams {
344 identifier: &format!("input_tokens/{}", event.idempotency_key),
345 event_name: &model.input_tokens_price.meter_event_name,
346 payload: StripeCreateMeterEventPayload {
347 value: event.input_tokens as u64,
348 stripe_customer_id: customer_id,
349 },
350 timestamp: Some(timestamp),
351 },
352 )
353 .await?;
354 }
355
356 if event.input_cache_creation_tokens > 0 {
357 StripeMeterEvent::create(
358 &self.client,
359 StripeCreateMeterEventParams {
360 identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
361 event_name: &model.input_cache_creation_tokens_price.meter_event_name,
362 payload: StripeCreateMeterEventPayload {
363 value: event.input_cache_creation_tokens as u64,
364 stripe_customer_id: customer_id,
365 },
366 timestamp: Some(timestamp),
367 },
368 )
369 .await?;
370 }
371
372 if event.input_cache_read_tokens > 0 {
373 StripeMeterEvent::create(
374 &self.client,
375 StripeCreateMeterEventParams {
376 identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
377 event_name: &model.input_cache_read_tokens_price.meter_event_name,
378 payload: StripeCreateMeterEventPayload {
379 value: event.input_cache_read_tokens as u64,
380 stripe_customer_id: customer_id,
381 },
382 timestamp: Some(timestamp),
383 },
384 )
385 .await?;
386 }
387
388 if event.output_tokens > 0 {
389 StripeMeterEvent::create(
390 &self.client,
391 StripeCreateMeterEventParams {
392 identifier: &format!("output_tokens/{}", event.idempotency_key),
393 event_name: &model.output_tokens_price.meter_event_name,
394 payload: StripeCreateMeterEventPayload {
395 value: event.output_tokens as u64,
396 stripe_customer_id: customer_id,
397 },
398 timestamp: Some(timestamp),
399 },
400 )
401 .await?;
402 }
403
404 Ok(())
405 }
406
407 pub async fn bill_model_request_usage(
408 &self,
409 customer_id: &stripe::CustomerId,
410 event_name: &str,
411 requests: i32,
412 ) -> Result<()> {
413 let timestamp = Utc::now().timestamp();
414 let idempotency_key = Uuid::new_v4();
415
416 StripeMeterEvent::create(
417 &self.client,
418 StripeCreateMeterEventParams {
419 identifier: &format!("model_requests/{}", idempotency_key),
420 event_name,
421 payload: StripeCreateMeterEventPayload {
422 value: requests as u64,
423 stripe_customer_id: customer_id,
424 },
425 timestamp: Some(timestamp),
426 },
427 )
428 .await?;
429
430 Ok(())
431 }
432
433 pub async fn checkout(
434 &self,
435 customer_id: stripe::CustomerId,
436 github_login: &str,
437 model: &StripeModelTokenPrices,
438 success_url: &str,
439 ) -> Result<String> {
440 let first_of_next_month = Utc::now()
441 .checked_add_months(chrono::Months::new(1))
442 .unwrap()
443 .with_day(1)
444 .unwrap();
445
446 let mut params = stripe::CreateCheckoutSession::new();
447 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
448 params.customer = Some(customer_id);
449 params.client_reference_id = Some(github_login);
450 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
451 billing_cycle_anchor: Some(first_of_next_month.timestamp()),
452 ..Default::default()
453 });
454 params.line_items = Some(
455 [
456 &model.input_tokens_price.id,
457 &model.input_cache_creation_tokens_price.id,
458 &model.input_cache_read_tokens_price.id,
459 &model.output_tokens_price.id,
460 ]
461 .into_iter()
462 .map(|price_id| stripe::CreateCheckoutSessionLineItems {
463 price: Some(price_id.to_string()),
464 ..Default::default()
465 })
466 .collect(),
467 );
468 params.success_url = Some(success_url);
469
470 let session = stripe::CheckoutSession::create(&self.client, params).await?;
471 Ok(session.url.context("no checkout session URL")?)
472 }
473
474 pub async fn checkout_with_zed_pro(
475 &self,
476 customer_id: stripe::CustomerId,
477 github_login: &str,
478 success_url: &str,
479 ) -> Result<String> {
480 let zed_pro_price_id = self.zed_pro_price_id().await?;
481
482 let mut params = stripe::CreateCheckoutSession::new();
483 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
484 params.customer = Some(customer_id);
485 params.client_reference_id = Some(github_login);
486 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
487 price: Some(zed_pro_price_id.to_string()),
488 quantity: Some(1),
489 ..Default::default()
490 }]);
491 params.success_url = Some(success_url);
492
493 let session = stripe::CheckoutSession::create(&self.client, params).await?;
494 Ok(session.url.context("no checkout session URL")?)
495 }
496
497 pub async fn checkout_with_zed_pro_trial(
498 &self,
499 customer_id: stripe::CustomerId,
500 github_login: &str,
501 feature_flags: Vec<String>,
502 success_url: &str,
503 ) -> Result<String> {
504 let zed_pro_price_id = self.zed_pro_price_id().await?;
505
506 let eligible_for_extended_trial = feature_flags
507 .iter()
508 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
509
510 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
511
512 let mut subscription_metadata = std::collections::HashMap::new();
513 if eligible_for_extended_trial {
514 subscription_metadata.insert(
515 "promo_feature_flag".to_string(),
516 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
517 );
518 }
519
520 let mut params = stripe::CreateCheckoutSession::new();
521 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
522 trial_period_days: Some(trial_period_days),
523 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
524 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
525 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
526 }
527 }),
528 metadata: if !subscription_metadata.is_empty() {
529 Some(subscription_metadata)
530 } else {
531 None
532 },
533 ..Default::default()
534 });
535 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
536 params.payment_method_collection =
537 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
538 params.customer = Some(customer_id);
539 params.client_reference_id = Some(github_login);
540 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
541 price: Some(zed_pro_price_id.to_string()),
542 quantity: Some(1),
543 ..Default::default()
544 }]);
545 params.success_url = Some(success_url);
546
547 let session = stripe::CheckoutSession::create(&self.client, params).await?;
548 Ok(session.url.context("no checkout session URL")?)
549 }
550}
551
552#[derive(Serialize)]
553struct DefaultAggregation {
554 formula: &'static str,
555}
556
557#[derive(Serialize)]
558struct StripeCreateMeterParams<'a> {
559 default_aggregation: DefaultAggregation,
560 display_name: String,
561 event_name: &'a str,
562}
563
564#[derive(Clone, Deserialize)]
565struct StripeMeter {
566 id: String,
567 event_name: String,
568}
569
570impl StripeMeter {
571 pub fn create(
572 client: &stripe::Client,
573 params: StripeCreateMeterParams,
574 ) -> stripe::Response<Self> {
575 client.post_form("/billing/meters", params)
576 }
577
578 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
579 #[derive(Serialize)]
580 struct Params {
581 #[serde(skip_serializing_if = "Option::is_none")]
582 limit: Option<u64>,
583 }
584
585 client.get_query("/billing/meters", Params { limit: Some(100) })
586 }
587}
588
589#[derive(Deserialize)]
590struct StripeMeterEvent {
591 identifier: String,
592}
593
594impl StripeMeterEvent {
595 pub async fn create(
596 client: &stripe::Client,
597 params: StripeCreateMeterEventParams<'_>,
598 ) -> Result<Self, stripe::StripeError> {
599 let identifier = params.identifier;
600 match client.post_form("/billing/meter_events", params).await {
601 Ok(event) => Ok(event),
602 Err(stripe::StripeError::Stripe(error)) => {
603 if error.http_status == 400
604 && error
605 .message
606 .as_ref()
607 .map_or(false, |message| message.contains(identifier))
608 {
609 Ok(Self {
610 identifier: identifier.to_string(),
611 })
612 } else {
613 Err(stripe::StripeError::Stripe(error))
614 }
615 }
616 Err(error) => Err(error),
617 }
618 }
619}
620
621#[derive(Serialize)]
622struct StripeCreateMeterEventParams<'a> {
623 identifier: &'a str,
624 event_name: &'a str,
625 payload: StripeCreateMeterEventPayload<'a>,
626 timestamp: Option<i64>,
627}
628
629#[derive(Serialize)]
630struct StripeCreateMeterEventPayload<'a> {
631 value: u64,
632 stripe_customer_id: &'a stripe::CustomerId,
633}
634
635fn subscription_contains_price(
636 subscription: &stripe::Subscription,
637 price_id: &stripe::PriceId,
638) -> bool {
639 subscription.items.data.iter().any(|item| {
640 item.price
641 .as_ref()
642 .map_or(false, |price| price.id == *price_id)
643 })
644}