1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::{anyhow, bail, Context};
6use axum::{
7 extract::{self, Query},
8 routing::{get, post},
9 Extension, Json, Router,
10};
11use chrono::{DateTime, SecondsFormat, Utc};
12use reqwest::StatusCode;
13use sea_orm::ActiveValue;
14use serde::{Deserialize, Serialize};
15use stripe::{
16 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
17 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
18 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
19 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
20 CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
21 Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::ResultExt;
24
25use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
26use crate::db::{
27 billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
28 CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
29 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
30};
31use crate::llm::db::LlmDatabase;
32use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
33use crate::rpc::ResultExt as _;
34use crate::{AppState, Error, Result};
35
36pub fn router() -> Router {
37 Router::new()
38 .route(
39 "/billing/preferences",
40 get(get_billing_preferences).put(update_billing_preferences),
41 )
42 .route(
43 "/billing/subscriptions",
44 get(list_billing_subscriptions).post(create_billing_subscription),
45 )
46 .route(
47 "/billing/subscriptions/manage",
48 post(manage_billing_subscription),
49 )
50}
51
52#[derive(Debug, Deserialize)]
53struct GetBillingPreferencesParams {
54 github_user_id: i32,
55}
56
57#[derive(Debug, Serialize)]
58struct BillingPreferencesResponse {
59 max_monthly_llm_usage_spending_in_cents: i32,
60}
61
62async fn get_billing_preferences(
63 Extension(app): Extension<Arc<AppState>>,
64 Query(params): Query<GetBillingPreferencesParams>,
65) -> Result<Json<BillingPreferencesResponse>> {
66 let user = app
67 .db
68 .get_user_by_github_user_id(params.github_user_id)
69 .await?
70 .ok_or_else(|| anyhow!("user not found"))?;
71
72 let preferences = app.db.get_billing_preferences(user.id).await?;
73
74 Ok(Json(BillingPreferencesResponse {
75 max_monthly_llm_usage_spending_in_cents: preferences
76 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
77 preferences.max_monthly_llm_usage_spending_in_cents
78 }),
79 }))
80}
81
82#[derive(Debug, Deserialize)]
83struct UpdateBillingPreferencesBody {
84 github_user_id: i32,
85 max_monthly_llm_usage_spending_in_cents: i32,
86}
87
88async fn update_billing_preferences(
89 Extension(app): Extension<Arc<AppState>>,
90 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
91) -> Result<Json<BillingPreferencesResponse>> {
92 let user = app
93 .db
94 .get_user_by_github_user_id(body.github_user_id)
95 .await?
96 .ok_or_else(|| anyhow!("user not found"))?;
97
98 let billing_preferences =
99 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
100 app.db
101 .update_billing_preferences(
102 user.id,
103 &UpdateBillingPreferencesParams {
104 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
105 body.max_monthly_llm_usage_spending_in_cents,
106 ),
107 },
108 )
109 .await?
110 } else {
111 app.db
112 .create_billing_preferences(
113 user.id,
114 &crate::db::CreateBillingPreferencesParams {
115 max_monthly_llm_usage_spending_in_cents: body
116 .max_monthly_llm_usage_spending_in_cents,
117 },
118 )
119 .await?
120 };
121
122 Ok(Json(BillingPreferencesResponse {
123 max_monthly_llm_usage_spending_in_cents: billing_preferences
124 .max_monthly_llm_usage_spending_in_cents,
125 }))
126}
127
128#[derive(Debug, Deserialize)]
129struct ListBillingSubscriptionsParams {
130 github_user_id: i32,
131}
132
133#[derive(Debug, Serialize)]
134struct BillingSubscriptionJson {
135 id: BillingSubscriptionId,
136 name: String,
137 status: StripeSubscriptionStatus,
138 cancel_at: Option<String>,
139 /// Whether this subscription can be canceled.
140 is_cancelable: bool,
141}
142
143#[derive(Debug, Serialize)]
144struct ListBillingSubscriptionsResponse {
145 subscriptions: Vec<BillingSubscriptionJson>,
146}
147
148async fn list_billing_subscriptions(
149 Extension(app): Extension<Arc<AppState>>,
150 Query(params): Query<ListBillingSubscriptionsParams>,
151) -> Result<Json<ListBillingSubscriptionsResponse>> {
152 let user = app
153 .db
154 .get_user_by_github_user_id(params.github_user_id)
155 .await?
156 .ok_or_else(|| anyhow!("user not found"))?;
157
158 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
159
160 Ok(Json(ListBillingSubscriptionsResponse {
161 subscriptions: subscriptions
162 .into_iter()
163 .map(|subscription| BillingSubscriptionJson {
164 id: subscription.id,
165 name: "Zed LLM Usage".to_string(),
166 status: subscription.stripe_subscription_status,
167 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
168 cancel_at
169 .and_utc()
170 .to_rfc3339_opts(SecondsFormat::Millis, true)
171 }),
172 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
173 && subscription.stripe_cancel_at.is_none(),
174 })
175 .collect(),
176 }))
177}
178
179#[derive(Debug, Deserialize)]
180struct CreateBillingSubscriptionBody {
181 github_user_id: i32,
182}
183
184#[derive(Debug, Serialize)]
185struct CreateBillingSubscriptionResponse {
186 checkout_session_url: String,
187}
188
189/// Initiates a Stripe Checkout session for creating a billing subscription.
190async fn create_billing_subscription(
191 Extension(app): Extension<Arc<AppState>>,
192 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
193) -> Result<Json<CreateBillingSubscriptionResponse>> {
194 let user = app
195 .db
196 .get_user_by_github_user_id(body.github_user_id)
197 .await?
198 .ok_or_else(|| anyhow!("user not found"))?;
199
200 let Some((stripe_client, stripe_access_price_id)) = app
201 .stripe_client
202 .clone()
203 .zip(app.config.stripe_llm_access_price_id.clone())
204 else {
205 log::error!("failed to retrieve Stripe client or price ID");
206 Err(Error::http(
207 StatusCode::NOT_IMPLEMENTED,
208 "not supported".into(),
209 ))?
210 };
211
212 let customer_id =
213 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
214 CustomerId::from_str(&existing_customer.stripe_customer_id)
215 .context("failed to parse customer ID")?
216 } else {
217 let customer = Customer::create(
218 &stripe_client,
219 CreateCustomer {
220 email: user.email_address.as_deref(),
221 ..Default::default()
222 },
223 )
224 .await?;
225
226 customer.id
227 };
228
229 let checkout_session = {
230 let mut params = CreateCheckoutSession::new();
231 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
232 params.customer = Some(customer_id);
233 params.client_reference_id = Some(user.github_login.as_str());
234 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
235 price: Some(stripe_access_price_id.to_string()),
236 quantity: Some(1),
237 ..Default::default()
238 }]);
239 let success_url = format!("{}/account", app.config.zed_dot_dev_url());
240 params.success_url = Some(&success_url);
241
242 CheckoutSession::create(&stripe_client, params).await?
243 };
244
245 Ok(Json(CreateBillingSubscriptionResponse {
246 checkout_session_url: checkout_session
247 .url
248 .ok_or_else(|| anyhow!("no checkout session URL"))?,
249 }))
250}
251
252#[derive(Debug, PartialEq, Deserialize)]
253#[serde(rename_all = "snake_case")]
254enum ManageSubscriptionIntent {
255 /// The user intends to cancel their subscription.
256 Cancel,
257 /// The user intends to stop the cancellation of their subscription.
258 StopCancellation,
259}
260
261#[derive(Debug, Deserialize)]
262struct ManageBillingSubscriptionBody {
263 github_user_id: i32,
264 intent: ManageSubscriptionIntent,
265 /// The ID of the subscription to manage.
266 subscription_id: BillingSubscriptionId,
267}
268
269#[derive(Debug, Serialize)]
270struct ManageBillingSubscriptionResponse {
271 billing_portal_session_url: Option<String>,
272}
273
274/// Initiates a Stripe customer portal session for managing a billing subscription.
275async fn manage_billing_subscription(
276 Extension(app): Extension<Arc<AppState>>,
277 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
278) -> Result<Json<ManageBillingSubscriptionResponse>> {
279 let user = app
280 .db
281 .get_user_by_github_user_id(body.github_user_id)
282 .await?
283 .ok_or_else(|| anyhow!("user not found"))?;
284
285 let Some(stripe_client) = app.stripe_client.clone() else {
286 log::error!("failed to retrieve Stripe client");
287 Err(Error::http(
288 StatusCode::NOT_IMPLEMENTED,
289 "not supported".into(),
290 ))?
291 };
292
293 let customer = app
294 .db
295 .get_billing_customer_by_user_id(user.id)
296 .await?
297 .ok_or_else(|| anyhow!("billing customer not found"))?;
298 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
299 .context("failed to parse customer ID")?;
300
301 let subscription = app
302 .db
303 .get_billing_subscription_by_id(body.subscription_id)
304 .await?
305 .ok_or_else(|| anyhow!("subscription not found"))?;
306
307 if body.intent == ManageSubscriptionIntent::StopCancellation {
308 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
309 .context("failed to parse subscription ID")?;
310
311 let updated_stripe_subscription = Subscription::update(
312 &stripe_client,
313 &subscription_id,
314 stripe::UpdateSubscription {
315 cancel_at_period_end: Some(false),
316 ..Default::default()
317 },
318 )
319 .await?;
320
321 app.db
322 .update_billing_subscription(
323 subscription.id,
324 &UpdateBillingSubscriptionParams {
325 stripe_cancel_at: ActiveValue::set(
326 updated_stripe_subscription
327 .cancel_at
328 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
329 .map(|time| time.naive_utc()),
330 ),
331 ..Default::default()
332 },
333 )
334 .await?;
335
336 return Ok(Json(ManageBillingSubscriptionResponse {
337 billing_portal_session_url: None,
338 }));
339 }
340
341 let flow = match body.intent {
342 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
343 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
344 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
345 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
346 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
347 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
348 }),
349 ..Default::default()
350 }),
351 subscription_cancel: Some(
352 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
353 subscription: subscription.stripe_subscription_id,
354 retention: None,
355 },
356 ),
357 ..Default::default()
358 },
359 ManageSubscriptionIntent::StopCancellation => unreachable!(),
360 };
361
362 let mut params = CreateBillingPortalSession::new(customer_id);
363 params.flow_data = Some(flow);
364 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
365 params.return_url = Some(&return_url);
366
367 let session = BillingPortalSession::create(&stripe_client, params).await?;
368
369 Ok(Json(ManageBillingSubscriptionResponse {
370 billing_portal_session_url: Some(session.url),
371 }))
372}
373
/// How long we wait between successive polls of the Stripe events API.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
///
/// We use a more conservative 5 seconds.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// The maximum number of events to request per page.
///
/// We use 100, Stripe's maximum, so that fewer requests are needed.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we will see
/// before we stop retrieving events.
///
/// This keeps us from repeatedly fetching events from the Stripe events API
/// that we've already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
400
401/// Polls the Stripe events API periodically to reconcile the records in our
402/// database with the data in Stripe.
403pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
404 let Some(stripe_client) = app.stripe_client.clone() else {
405 log::warn!("failed to retrieve Stripe client");
406 return;
407 };
408
409 let executor = app.executor.clone();
410 executor.spawn_detached({
411 let executor = executor.clone();
412 async move {
413 loop {
414 poll_stripe_events(&app, &stripe_client).await.log_err();
415
416 executor.sleep(POLL_EVENTS_INTERVAL).await;
417 }
418 }
419 });
420}
421
422async fn poll_stripe_events(
423 app: &Arc<AppState>,
424 stripe_client: &stripe::Client,
425) -> anyhow::Result<()> {
426 fn event_type_to_string(event_type: EventType) -> String {
427 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
428 // so we need to unquote it.
429 event_type.to_string().trim_matches('"').to_string()
430 }
431
432 let event_types = [
433 EventType::CustomerCreated,
434 EventType::CustomerUpdated,
435 EventType::CustomerSubscriptionCreated,
436 EventType::CustomerSubscriptionUpdated,
437 EventType::CustomerSubscriptionPaused,
438 EventType::CustomerSubscriptionResumed,
439 EventType::CustomerSubscriptionDeleted,
440 ]
441 .into_iter()
442 .map(event_type_to_string)
443 .collect::<Vec<_>>();
444
445 let mut pages_of_already_processed_events = 0;
446 let mut unprocessed_events = Vec::new();
447
448 loop {
449 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP {
450 log::info!("saw {pages_of_already_processed_events} pages of already-processed events: stopping event retrieval");
451 break;
452 }
453
454 log::info!("retrieving events from Stripe: {}", event_types.join(", "));
455
456 let mut params = ListEvents::new();
457 params.types = Some(event_types.clone());
458 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
459
460 let events = stripe::Event::list(stripe_client, ¶ms).await?;
461
462 let processed_event_ids = {
463 let event_ids = &events
464 .data
465 .iter()
466 .map(|event| event.id.as_str())
467 .collect::<Vec<_>>();
468
469 app.db
470 .get_processed_stripe_events_by_event_ids(event_ids)
471 .await?
472 .into_iter()
473 .map(|event| event.stripe_event_id)
474 .collect::<Vec<_>>()
475 };
476
477 let mut processed_events_in_page = 0;
478 let events_in_page = events.data.len();
479 for event in events.data {
480 if processed_event_ids.contains(&event.id.to_string()) {
481 processed_events_in_page += 1;
482 log::debug!("Stripe event {} already processed: skipping", event.id);
483 } else {
484 unprocessed_events.push(event);
485 }
486 }
487
488 if processed_events_in_page == events_in_page {
489 pages_of_already_processed_events += 1;
490 }
491
492 if !events.has_more {
493 break;
494 }
495 }
496
497 log::info!(
498 "unprocessed events from Stripe: {}",
499 unprocessed_events.len()
500 );
501
502 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
503 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
504
505 for event in unprocessed_events {
506 let event_id = event.id.clone();
507 let processed_event_params = CreateProcessedStripeEventParams {
508 stripe_event_id: event.id.to_string(),
509 stripe_event_type: event_type_to_string(event.type_),
510 stripe_event_created_timestamp: event.created,
511 };
512
513 // If the event has happened too far in the past, we don't want to
514 // process it and risk overwriting other more-recent updates.
515 //
516 // 1 hour was chosen arbitrarily. This could be made longer or shorter.
517 let one_hour = Duration::from_secs(60 * 60);
518 let an_hour_ago = Utc::now() - one_hour;
519 if an_hour_ago.timestamp() > event.created {
520 log::info!(
521 "Stripe event {} is more than {one_hour:?} old, marking as processed",
522 event_id
523 );
524 app.db
525 .create_processed_stripe_event(&processed_event_params)
526 .await?;
527
528 return Ok(());
529 }
530
531 let process_result = match event.type_ {
532 EventType::CustomerCreated | EventType::CustomerUpdated => {
533 handle_customer_event(app, stripe_client, event).await
534 }
535 EventType::CustomerSubscriptionCreated
536 | EventType::CustomerSubscriptionUpdated
537 | EventType::CustomerSubscriptionPaused
538 | EventType::CustomerSubscriptionResumed
539 | EventType::CustomerSubscriptionDeleted => {
540 handle_customer_subscription_event(app, stripe_client, event).await
541 }
542 _ => Ok(()),
543 };
544
545 if let Some(()) = process_result
546 .with_context(|| format!("failed to process event {event_id} successfully"))
547 .log_err()
548 {
549 app.db
550 .create_processed_stripe_event(&processed_event_params)
551 .await?;
552 }
553 }
554
555 Ok(())
556}
557
558async fn handle_customer_event(
559 app: &Arc<AppState>,
560 _stripe_client: &stripe::Client,
561 event: stripe::Event,
562) -> anyhow::Result<()> {
563 let EventObject::Customer(customer) = event.data.object else {
564 bail!("unexpected event payload for {}", event.id);
565 };
566
567 log::info!("handling Stripe {} event: {}", event.type_, event.id);
568
569 let Some(email) = customer.email else {
570 log::info!("Stripe customer has no email: skipping");
571 return Ok(());
572 };
573
574 let Some(user) = app.db.get_user_by_email(&email).await? else {
575 log::info!("no user found for email: skipping");
576 return Ok(());
577 };
578
579 if let Some(existing_customer) = app
580 .db
581 .get_billing_customer_by_stripe_customer_id(&customer.id)
582 .await?
583 {
584 app.db
585 .update_billing_customer(
586 existing_customer.id,
587 &UpdateBillingCustomerParams {
588 // For now we just leave the information as-is, as it is not
589 // likely to change.
590 ..Default::default()
591 },
592 )
593 .await?;
594 } else {
595 app.db
596 .create_billing_customer(&CreateBillingCustomerParams {
597 user_id: user.id,
598 stripe_customer_id: customer.id.to_string(),
599 })
600 .await?;
601 }
602
603 Ok(())
604}
605
606async fn handle_customer_subscription_event(
607 app: &Arc<AppState>,
608 stripe_client: &stripe::Client,
609 event: stripe::Event,
610) -> anyhow::Result<()> {
611 let EventObject::Subscription(subscription) = event.data.object else {
612 bail!("unexpected event payload for {}", event.id);
613 };
614
615 log::info!("handling Stripe {} event: {}", event.type_, event.id);
616
617 let billing_customer =
618 find_or_create_billing_customer(app, stripe_client, subscription.customer)
619 .await?
620 .ok_or_else(|| anyhow!("billing customer not found"))?;
621
622 if let Some(existing_subscription) = app
623 .db
624 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
625 .await?
626 {
627 app.db
628 .update_billing_subscription(
629 existing_subscription.id,
630 &UpdateBillingSubscriptionParams {
631 billing_customer_id: ActiveValue::set(billing_customer.id),
632 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
633 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
634 stripe_cancel_at: ActiveValue::set(
635 subscription
636 .cancel_at
637 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
638 .map(|time| time.naive_utc()),
639 ),
640 },
641 )
642 .await?;
643 } else {
644 app.db
645 .create_billing_subscription(&CreateBillingSubscriptionParams {
646 billing_customer_id: billing_customer.id,
647 stripe_subscription_id: subscription.id.to_string(),
648 stripe_subscription_status: subscription.status.into(),
649 })
650 .await?;
651 }
652
653 Ok(())
654}
655
656impl From<SubscriptionStatus> for StripeSubscriptionStatus {
657 fn from(value: SubscriptionStatus) -> Self {
658 match value {
659 SubscriptionStatus::Incomplete => Self::Incomplete,
660 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
661 SubscriptionStatus::Trialing => Self::Trialing,
662 SubscriptionStatus::Active => Self::Active,
663 SubscriptionStatus::PastDue => Self::PastDue,
664 SubscriptionStatus::Canceled => Self::Canceled,
665 SubscriptionStatus::Unpaid => Self::Unpaid,
666 SubscriptionStatus::Paused => Self::Paused,
667 }
668 }
669}
670
671/// Finds or creates a billing customer using the provided customer.
672async fn find_or_create_billing_customer(
673 app: &Arc<AppState>,
674 stripe_client: &stripe::Client,
675 customer_or_id: Expandable<Customer>,
676) -> anyhow::Result<Option<billing_customer::Model>> {
677 let customer_id = match &customer_or_id {
678 Expandable::Id(id) => id,
679 Expandable::Object(customer) => customer.id.as_ref(),
680 };
681
682 // If we already have a billing customer record associated with the Stripe customer,
683 // there's nothing more we need to do.
684 if let Some(billing_customer) = app
685 .db
686 .get_billing_customer_by_stripe_customer_id(customer_id)
687 .await?
688 {
689 return Ok(Some(billing_customer));
690 }
691
692 // If all we have is a customer ID, resolve it to a full customer record by
693 // hitting the Stripe API.
694 let customer = match customer_or_id {
695 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
696 Expandable::Object(customer) => *customer,
697 };
698
699 let Some(email) = customer.email else {
700 return Ok(None);
701 };
702
703 let Some(user) = app.db.get_user_by_email(&email).await? else {
704 return Ok(None);
705 };
706
707 let billing_customer = app
708 .db
709 .create_billing_customer(&CreateBillingCustomerParams {
710 user_id: user.id,
711 stripe_customer_id: customer.id.to_string(),
712 })
713 .await?;
714
715 Ok(Some(billing_customer))
716}
717
/// How often LLM usage is pushed to Stripe: once every 24 hours.
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
719
720pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
721 let Some(stripe_client) = app.stripe_client.clone() else {
722 log::warn!("failed to retrieve Stripe client");
723 return;
724 };
725 let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
726 log::warn!("failed to retrieve Stripe LLM usage price ID");
727 return;
728 };
729
730 let executor = app.executor.clone();
731 executor.spawn_detached({
732 let executor = executor.clone();
733 async move {
734 loop {
735 sync_with_stripe(
736 &app,
737 &llm_db,
738 &stripe_client,
739 stripe_llm_usage_price_id.clone(),
740 )
741 .await
742 .trace_err();
743
744 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
745 }
746 }
747 });
748}
749
750async fn sync_with_stripe(
751 app: &Arc<AppState>,
752 llm_db: &LlmDatabase,
753 stripe_client: &stripe::Client,
754 stripe_llm_usage_price_id: Arc<str>,
755) -> anyhow::Result<()> {
756 let subscriptions = app.db.get_active_billing_subscriptions().await?;
757
758 for (customer, subscription) in subscriptions {
759 update_stripe_subscription(
760 llm_db,
761 stripe_client,
762 &stripe_llm_usage_price_id,
763 customer,
764 subscription,
765 )
766 .await
767 .log_err();
768 }
769
770 Ok(())
771}
772
773async fn update_stripe_subscription(
774 llm_db: &LlmDatabase,
775 stripe_client: &stripe::Client,
776 stripe_llm_usage_price_id: &Arc<str>,
777 customer: billing_customer::Model,
778 subscription: billing_subscription::Model,
779) -> Result<(), anyhow::Error> {
780 let monthly_spending = llm_db
781 .get_user_spending_for_month(customer.user_id, Utc::now())
782 .await?;
783 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
784 .context("failed to parse subscription ID")?;
785
786 let monthly_spending_over_free_tier =
787 monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
788
789 let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
790 let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
791
792 let mut update_params = stripe::UpdateSubscription {
793 proration_behavior: Some(
794 stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
795 ),
796 ..Default::default()
797 };
798
799 if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
800 item.price.as_ref().map_or(false, |price| {
801 price.id == stripe_llm_usage_price_id.as_ref()
802 })
803 }) {
804 update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
805 id: Some(existing_item.id.to_string()),
806 quantity: Some(new_quantity as u64),
807 ..Default::default()
808 }]);
809 } else {
810 update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
811 price: Some(stripe_llm_usage_price_id.to_string()),
812 quantity: Some(new_quantity as u64),
813 ..Default::default()
814 }]);
815 }
816
817 Subscription::update(stripe_client, &subscription_id, update_params).await?;
818 Ok(())
819}