billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3use std::time::Duration;
  4
  5use anyhow::{anyhow, bail, Context};
  6use axum::{
  7    extract::{self, Query},
  8    routing::{get, post},
  9    Extension, Json, Router,
 10};
 11use chrono::{DateTime, SecondsFormat, Utc};
 12use reqwest::StatusCode;
 13use sea_orm::ActiveValue;
 14use serde::{Deserialize, Serialize};
 15use stripe::{
 16    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 17    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 18    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 19    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 20    CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
 21    Subscription, SubscriptionId, SubscriptionStatus,
 22};
 23use util::ResultExt;
 24
 25use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
 26use crate::db::{
 27    billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
 28    CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
 29    UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
 30};
 31use crate::llm::db::LlmDatabase;
 32use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
 33use crate::rpc::ResultExt as _;
 34use crate::{AppState, Error, Result};
 35
 36pub fn router() -> Router {
 37    Router::new()
 38        .route(
 39            "/billing/preferences",
 40            get(get_billing_preferences).put(update_billing_preferences),
 41        )
 42        .route(
 43            "/billing/subscriptions",
 44            get(list_billing_subscriptions).post(create_billing_subscription),
 45        )
 46        .route(
 47            "/billing/subscriptions/manage",
 48            post(manage_billing_subscription),
 49        )
 50}
 51
 52#[derive(Debug, Deserialize)]
 53struct GetBillingPreferencesParams {
 54    github_user_id: i32,
 55}
 56
 57#[derive(Debug, Serialize)]
 58struct BillingPreferencesResponse {
 59    max_monthly_llm_usage_spending_in_cents: i32,
 60}
 61
 62async fn get_billing_preferences(
 63    Extension(app): Extension<Arc<AppState>>,
 64    Query(params): Query<GetBillingPreferencesParams>,
 65) -> Result<Json<BillingPreferencesResponse>> {
 66    let user = app
 67        .db
 68        .get_user_by_github_user_id(params.github_user_id)
 69        .await?
 70        .ok_or_else(|| anyhow!("user not found"))?;
 71
 72    let preferences = app.db.get_billing_preferences(user.id).await?;
 73
 74    Ok(Json(BillingPreferencesResponse {
 75        max_monthly_llm_usage_spending_in_cents: preferences
 76            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
 77                preferences.max_monthly_llm_usage_spending_in_cents
 78            }),
 79    }))
 80}
 81
 82#[derive(Debug, Deserialize)]
 83struct UpdateBillingPreferencesBody {
 84    github_user_id: i32,
 85    max_monthly_llm_usage_spending_in_cents: i32,
 86}
 87
 88async fn update_billing_preferences(
 89    Extension(app): Extension<Arc<AppState>>,
 90    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 91) -> Result<Json<BillingPreferencesResponse>> {
 92    let user = app
 93        .db
 94        .get_user_by_github_user_id(body.github_user_id)
 95        .await?
 96        .ok_or_else(|| anyhow!("user not found"))?;
 97
 98    let billing_preferences =
 99        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
100            app.db
101                .update_billing_preferences(
102                    user.id,
103                    &UpdateBillingPreferencesParams {
104                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
105                            body.max_monthly_llm_usage_spending_in_cents,
106                        ),
107                    },
108                )
109                .await?
110        } else {
111            app.db
112                .create_billing_preferences(
113                    user.id,
114                    &crate::db::CreateBillingPreferencesParams {
115                        max_monthly_llm_usage_spending_in_cents: body
116                            .max_monthly_llm_usage_spending_in_cents,
117                    },
118                )
119                .await?
120        };
121
122    Ok(Json(BillingPreferencesResponse {
123        max_monthly_llm_usage_spending_in_cents: billing_preferences
124            .max_monthly_llm_usage_spending_in_cents,
125    }))
126}
127
128#[derive(Debug, Deserialize)]
129struct ListBillingSubscriptionsParams {
130    github_user_id: i32,
131}
132
133#[derive(Debug, Serialize)]
134struct BillingSubscriptionJson {
135    id: BillingSubscriptionId,
136    name: String,
137    status: StripeSubscriptionStatus,
138    cancel_at: Option<String>,
139    /// Whether this subscription can be canceled.
140    is_cancelable: bool,
141}
142
143#[derive(Debug, Serialize)]
144struct ListBillingSubscriptionsResponse {
145    subscriptions: Vec<BillingSubscriptionJson>,
146}
147
148async fn list_billing_subscriptions(
149    Extension(app): Extension<Arc<AppState>>,
150    Query(params): Query<ListBillingSubscriptionsParams>,
151) -> Result<Json<ListBillingSubscriptionsResponse>> {
152    let user = app
153        .db
154        .get_user_by_github_user_id(params.github_user_id)
155        .await?
156        .ok_or_else(|| anyhow!("user not found"))?;
157
158    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
159
160    Ok(Json(ListBillingSubscriptionsResponse {
161        subscriptions: subscriptions
162            .into_iter()
163            .map(|subscription| BillingSubscriptionJson {
164                id: subscription.id,
165                name: "Zed LLM Usage".to_string(),
166                status: subscription.stripe_subscription_status,
167                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
168                    cancel_at
169                        .and_utc()
170                        .to_rfc3339_opts(SecondsFormat::Millis, true)
171                }),
172                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
173                    && subscription.stripe_cancel_at.is_none(),
174            })
175            .collect(),
176    }))
177}
178
179#[derive(Debug, Deserialize)]
180struct CreateBillingSubscriptionBody {
181    github_user_id: i32,
182}
183
184#[derive(Debug, Serialize)]
185struct CreateBillingSubscriptionResponse {
186    checkout_session_url: String,
187}
188
189/// Initiates a Stripe Checkout session for creating a billing subscription.
190async fn create_billing_subscription(
191    Extension(app): Extension<Arc<AppState>>,
192    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
193) -> Result<Json<CreateBillingSubscriptionResponse>> {
194    let user = app
195        .db
196        .get_user_by_github_user_id(body.github_user_id)
197        .await?
198        .ok_or_else(|| anyhow!("user not found"))?;
199
200    let Some((stripe_client, stripe_access_price_id)) = app
201        .stripe_client
202        .clone()
203        .zip(app.config.stripe_llm_access_price_id.clone())
204    else {
205        log::error!("failed to retrieve Stripe client or price ID");
206        Err(Error::http(
207            StatusCode::NOT_IMPLEMENTED,
208            "not supported".into(),
209        ))?
210    };
211
212    let customer_id =
213        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
214            CustomerId::from_str(&existing_customer.stripe_customer_id)
215                .context("failed to parse customer ID")?
216        } else {
217            let customer = Customer::create(
218                &stripe_client,
219                CreateCustomer {
220                    email: user.email_address.as_deref(),
221                    ..Default::default()
222                },
223            )
224            .await?;
225
226            customer.id
227        };
228
229    let checkout_session = {
230        let mut params = CreateCheckoutSession::new();
231        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
232        params.customer = Some(customer_id);
233        params.client_reference_id = Some(user.github_login.as_str());
234        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
235            price: Some(stripe_access_price_id.to_string()),
236            quantity: Some(1),
237            ..Default::default()
238        }]);
239        let success_url = format!("{}/account", app.config.zed_dot_dev_url());
240        params.success_url = Some(&success_url);
241
242        CheckoutSession::create(&stripe_client, params).await?
243    };
244
245    Ok(Json(CreateBillingSubscriptionResponse {
246        checkout_session_url: checkout_session
247            .url
248            .ok_or_else(|| anyhow!("no checkout session URL"))?,
249    }))
250}
251
252#[derive(Debug, PartialEq, Deserialize)]
253#[serde(rename_all = "snake_case")]
254enum ManageSubscriptionIntent {
255    /// The user intends to cancel their subscription.
256    Cancel,
257    /// The user intends to stop the cancellation of their subscription.
258    StopCancellation,
259}
260
261#[derive(Debug, Deserialize)]
262struct ManageBillingSubscriptionBody {
263    github_user_id: i32,
264    intent: ManageSubscriptionIntent,
265    /// The ID of the subscription to manage.
266    subscription_id: BillingSubscriptionId,
267}
268
269#[derive(Debug, Serialize)]
270struct ManageBillingSubscriptionResponse {
271    billing_portal_session_url: Option<String>,
272}
273
274/// Initiates a Stripe customer portal session for managing a billing subscription.
275async fn manage_billing_subscription(
276    Extension(app): Extension<Arc<AppState>>,
277    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
278) -> Result<Json<ManageBillingSubscriptionResponse>> {
279    let user = app
280        .db
281        .get_user_by_github_user_id(body.github_user_id)
282        .await?
283        .ok_or_else(|| anyhow!("user not found"))?;
284
285    let Some(stripe_client) = app.stripe_client.clone() else {
286        log::error!("failed to retrieve Stripe client");
287        Err(Error::http(
288            StatusCode::NOT_IMPLEMENTED,
289            "not supported".into(),
290        ))?
291    };
292
293    let customer = app
294        .db
295        .get_billing_customer_by_user_id(user.id)
296        .await?
297        .ok_or_else(|| anyhow!("billing customer not found"))?;
298    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
299        .context("failed to parse customer ID")?;
300
301    let subscription = app
302        .db
303        .get_billing_subscription_by_id(body.subscription_id)
304        .await?
305        .ok_or_else(|| anyhow!("subscription not found"))?;
306
307    if body.intent == ManageSubscriptionIntent::StopCancellation {
308        let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
309            .context("failed to parse subscription ID")?;
310
311        let updated_stripe_subscription = Subscription::update(
312            &stripe_client,
313            &subscription_id,
314            stripe::UpdateSubscription {
315                cancel_at_period_end: Some(false),
316                ..Default::default()
317            },
318        )
319        .await?;
320
321        app.db
322            .update_billing_subscription(
323                subscription.id,
324                &UpdateBillingSubscriptionParams {
325                    stripe_cancel_at: ActiveValue::set(
326                        updated_stripe_subscription
327                            .cancel_at
328                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
329                            .map(|time| time.naive_utc()),
330                    ),
331                    ..Default::default()
332                },
333            )
334            .await?;
335
336        return Ok(Json(ManageBillingSubscriptionResponse {
337            billing_portal_session_url: None,
338        }));
339    }
340
341    let flow = match body.intent {
342        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
343            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
344            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
345                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
346                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
347                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
348                }),
349                ..Default::default()
350            }),
351            subscription_cancel: Some(
352                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
353                    subscription: subscription.stripe_subscription_id,
354                    retention: None,
355                },
356            ),
357            ..Default::default()
358        },
359        ManageSubscriptionIntent::StopCancellation => unreachable!(),
360    };
361
362    let mut params = CreateBillingPortalSession::new(customer_id);
363    params.flow_data = Some(flow);
364    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
365    params.return_url = Some(&return_url);
366
367    let session = BillingPortalSession::create(&stripe_client, params).await?;
368
369    Ok(Json(ManageBillingSubscriptionResponse {
370        billing_portal_session_url: Some(session.url),
371    }))
372}
373
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
///   1. Being short enough that we update quickly when something in Stripe changes
///   2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
///
/// We currently use a more conservative 5 seconds.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);

/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
/// The documented bounds for the list endpoint's `limit` parameter:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
400
401/// Polls the Stripe events API periodically to reconcile the records in our
402/// database with the data in Stripe.
403pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
404    let Some(stripe_client) = app.stripe_client.clone() else {
405        log::warn!("failed to retrieve Stripe client");
406        return;
407    };
408
409    let executor = app.executor.clone();
410    executor.spawn_detached({
411        let executor = executor.clone();
412        async move {
413            loop {
414                poll_stripe_events(&app, &stripe_client).await.log_err();
415
416                executor.sleep(POLL_EVENTS_INTERVAL).await;
417            }
418        }
419    });
420}
421
422async fn poll_stripe_events(
423    app: &Arc<AppState>,
424    stripe_client: &stripe::Client,
425) -> anyhow::Result<()> {
426    fn event_type_to_string(event_type: EventType) -> String {
427        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
428        // so we need to unquote it.
429        event_type.to_string().trim_matches('"').to_string()
430    }
431
432    let event_types = [
433        EventType::CustomerCreated,
434        EventType::CustomerUpdated,
435        EventType::CustomerSubscriptionCreated,
436        EventType::CustomerSubscriptionUpdated,
437        EventType::CustomerSubscriptionPaused,
438        EventType::CustomerSubscriptionResumed,
439        EventType::CustomerSubscriptionDeleted,
440    ]
441    .into_iter()
442    .map(event_type_to_string)
443    .collect::<Vec<_>>();
444
445    let mut pages_of_already_processed_events = 0;
446    let mut unprocessed_events = Vec::new();
447
448    loop {
449        if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP {
450            log::info!("saw {pages_of_already_processed_events} pages of already-processed events: stopping event retrieval");
451            break;
452        }
453
454        log::info!("retrieving events from Stripe: {}", event_types.join(", "));
455
456        let mut params = ListEvents::new();
457        params.types = Some(event_types.clone());
458        params.limit = Some(EVENTS_LIMIT_PER_PAGE);
459
460        let events = stripe::Event::list(stripe_client, &params).await?;
461
462        let processed_event_ids = {
463            let event_ids = &events
464                .data
465                .iter()
466                .map(|event| event.id.as_str())
467                .collect::<Vec<_>>();
468
469            app.db
470                .get_processed_stripe_events_by_event_ids(event_ids)
471                .await?
472                .into_iter()
473                .map(|event| event.stripe_event_id)
474                .collect::<Vec<_>>()
475        };
476
477        let mut processed_events_in_page = 0;
478        let events_in_page = events.data.len();
479        for event in events.data {
480            if processed_event_ids.contains(&event.id.to_string()) {
481                processed_events_in_page += 1;
482                log::debug!("Stripe event {} already processed: skipping", event.id);
483            } else {
484                unprocessed_events.push(event);
485            }
486        }
487
488        if processed_events_in_page == events_in_page {
489            pages_of_already_processed_events += 1;
490        }
491
492        if !events.has_more {
493            break;
494        }
495    }
496
497    log::info!(
498        "unprocessed events from Stripe: {}",
499        unprocessed_events.len()
500    );
501
502    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
503    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
504
505    for event in unprocessed_events {
506        let event_id = event.id.clone();
507        let processed_event_params = CreateProcessedStripeEventParams {
508            stripe_event_id: event.id.to_string(),
509            stripe_event_type: event_type_to_string(event.type_),
510            stripe_event_created_timestamp: event.created,
511        };
512
513        // If the event has happened too far in the past, we don't want to
514        // process it and risk overwriting other more-recent updates.
515        //
516        // 1 hour was chosen arbitrarily. This could be made longer or shorter.
517        let one_hour = Duration::from_secs(60 * 60);
518        let an_hour_ago = Utc::now() - one_hour;
519        if an_hour_ago.timestamp() > event.created {
520            log::info!(
521                "Stripe event {} is more than {one_hour:?} old, marking as processed",
522                event_id
523            );
524            app.db
525                .create_processed_stripe_event(&processed_event_params)
526                .await?;
527
528            return Ok(());
529        }
530
531        let process_result = match event.type_ {
532            EventType::CustomerCreated | EventType::CustomerUpdated => {
533                handle_customer_event(app, stripe_client, event).await
534            }
535            EventType::CustomerSubscriptionCreated
536            | EventType::CustomerSubscriptionUpdated
537            | EventType::CustomerSubscriptionPaused
538            | EventType::CustomerSubscriptionResumed
539            | EventType::CustomerSubscriptionDeleted => {
540                handle_customer_subscription_event(app, stripe_client, event).await
541            }
542            _ => Ok(()),
543        };
544
545        if let Some(()) = process_result
546            .with_context(|| format!("failed to process event {event_id} successfully"))
547            .log_err()
548        {
549            app.db
550                .create_processed_stripe_event(&processed_event_params)
551                .await?;
552        }
553    }
554
555    Ok(())
556}
557
558async fn handle_customer_event(
559    app: &Arc<AppState>,
560    _stripe_client: &stripe::Client,
561    event: stripe::Event,
562) -> anyhow::Result<()> {
563    let EventObject::Customer(customer) = event.data.object else {
564        bail!("unexpected event payload for {}", event.id);
565    };
566
567    log::info!("handling Stripe {} event: {}", event.type_, event.id);
568
569    let Some(email) = customer.email else {
570        log::info!("Stripe customer has no email: skipping");
571        return Ok(());
572    };
573
574    let Some(user) = app.db.get_user_by_email(&email).await? else {
575        log::info!("no user found for email: skipping");
576        return Ok(());
577    };
578
579    if let Some(existing_customer) = app
580        .db
581        .get_billing_customer_by_stripe_customer_id(&customer.id)
582        .await?
583    {
584        app.db
585            .update_billing_customer(
586                existing_customer.id,
587                &UpdateBillingCustomerParams {
588                    // For now we just leave the information as-is, as it is not
589                    // likely to change.
590                    ..Default::default()
591                },
592            )
593            .await?;
594    } else {
595        app.db
596            .create_billing_customer(&CreateBillingCustomerParams {
597                user_id: user.id,
598                stripe_customer_id: customer.id.to_string(),
599            })
600            .await?;
601    }
602
603    Ok(())
604}
605
606async fn handle_customer_subscription_event(
607    app: &Arc<AppState>,
608    stripe_client: &stripe::Client,
609    event: stripe::Event,
610) -> anyhow::Result<()> {
611    let EventObject::Subscription(subscription) = event.data.object else {
612        bail!("unexpected event payload for {}", event.id);
613    };
614
615    log::info!("handling Stripe {} event: {}", event.type_, event.id);
616
617    let billing_customer =
618        find_or_create_billing_customer(app, stripe_client, subscription.customer)
619            .await?
620            .ok_or_else(|| anyhow!("billing customer not found"))?;
621
622    if let Some(existing_subscription) = app
623        .db
624        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
625        .await?
626    {
627        app.db
628            .update_billing_subscription(
629                existing_subscription.id,
630                &UpdateBillingSubscriptionParams {
631                    billing_customer_id: ActiveValue::set(billing_customer.id),
632                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
633                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
634                    stripe_cancel_at: ActiveValue::set(
635                        subscription
636                            .cancel_at
637                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
638                            .map(|time| time.naive_utc()),
639                    ),
640                },
641            )
642            .await?;
643    } else {
644        app.db
645            .create_billing_subscription(&CreateBillingSubscriptionParams {
646                billing_customer_id: billing_customer.id,
647                stripe_subscription_id: subscription.id.to_string(),
648                stripe_subscription_status: subscription.status.into(),
649            })
650            .await?;
651    }
652
653    Ok(())
654}
655
656impl From<SubscriptionStatus> for StripeSubscriptionStatus {
657    fn from(value: SubscriptionStatus) -> Self {
658        match value {
659            SubscriptionStatus::Incomplete => Self::Incomplete,
660            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
661            SubscriptionStatus::Trialing => Self::Trialing,
662            SubscriptionStatus::Active => Self::Active,
663            SubscriptionStatus::PastDue => Self::PastDue,
664            SubscriptionStatus::Canceled => Self::Canceled,
665            SubscriptionStatus::Unpaid => Self::Unpaid,
666            SubscriptionStatus::Paused => Self::Paused,
667        }
668    }
669}
670
671/// Finds or creates a billing customer using the provided customer.
672async fn find_or_create_billing_customer(
673    app: &Arc<AppState>,
674    stripe_client: &stripe::Client,
675    customer_or_id: Expandable<Customer>,
676) -> anyhow::Result<Option<billing_customer::Model>> {
677    let customer_id = match &customer_or_id {
678        Expandable::Id(id) => id,
679        Expandable::Object(customer) => customer.id.as_ref(),
680    };
681
682    // If we already have a billing customer record associated with the Stripe customer,
683    // there's nothing more we need to do.
684    if let Some(billing_customer) = app
685        .db
686        .get_billing_customer_by_stripe_customer_id(customer_id)
687        .await?
688    {
689        return Ok(Some(billing_customer));
690    }
691
692    // If all we have is a customer ID, resolve it to a full customer record by
693    // hitting the Stripe API.
694    let customer = match customer_or_id {
695        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
696        Expandable::Object(customer) => *customer,
697    };
698
699    let Some(email) = customer.email else {
700        return Ok(None);
701    };
702
703    let Some(user) = app.db.get_user_by_email(&email).await? else {
704        return Ok(None);
705    };
706
707    let billing_customer = app
708        .db
709        .create_billing_customer(&CreateBillingCustomerParams {
710            user_id: user.id,
711            stripe_customer_id: customer.id.to_string(),
712        })
713        .await?;
714
715    Ok(Some(billing_customer))
716}
717
718const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
719
720pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
721    let Some(stripe_client) = app.stripe_client.clone() else {
722        log::warn!("failed to retrieve Stripe client");
723        return;
724    };
725    let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
726        log::warn!("failed to retrieve Stripe LLM usage price ID");
727        return;
728    };
729
730    let executor = app.executor.clone();
731    executor.spawn_detached({
732        let executor = executor.clone();
733        async move {
734            loop {
735                sync_with_stripe(
736                    &app,
737                    &llm_db,
738                    &stripe_client,
739                    stripe_llm_usage_price_id.clone(),
740                )
741                .await
742                .trace_err();
743
744                executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
745            }
746        }
747    });
748}
749
750async fn sync_with_stripe(
751    app: &Arc<AppState>,
752    llm_db: &LlmDatabase,
753    stripe_client: &stripe::Client,
754    stripe_llm_usage_price_id: Arc<str>,
755) -> anyhow::Result<()> {
756    let subscriptions = app.db.get_active_billing_subscriptions().await?;
757
758    for (customer, subscription) in subscriptions {
759        update_stripe_subscription(
760            llm_db,
761            stripe_client,
762            &stripe_llm_usage_price_id,
763            customer,
764            subscription,
765        )
766        .await
767        .log_err();
768    }
769
770    Ok(())
771}
772
773async fn update_stripe_subscription(
774    llm_db: &LlmDatabase,
775    stripe_client: &stripe::Client,
776    stripe_llm_usage_price_id: &Arc<str>,
777    customer: billing_customer::Model,
778    subscription: billing_subscription::Model,
779) -> Result<(), anyhow::Error> {
780    let monthly_spending = llm_db
781        .get_user_spending_for_month(customer.user_id, Utc::now())
782        .await?;
783    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
784        .context("failed to parse subscription ID")?;
785
786    let monthly_spending_over_free_tier =
787        monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
788
789    let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
790    let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
791
792    let mut update_params = stripe::UpdateSubscription {
793        proration_behavior: Some(
794            stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
795        ),
796        ..Default::default()
797    };
798
799    if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
800        item.price.as_ref().map_or(false, |price| {
801            price.id == stripe_llm_usage_price_id.as_ref()
802        })
803    }) {
804        update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
805            id: Some(existing_item.id.to_string()),
806            quantity: Some(new_quantity as u64),
807            ..Default::default()
808        }]);
809    } else {
810        update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
811            price: Some(stripe_llm_usage_price_id.to_string()),
812            quantity: Some(new_quantity as u64),
813            ..Default::default()
814        }]);
815    }
816
817    Subscription::update(stripe_client, &subscription_id, update_params).await?;
818    Ok(())
819}