1use anyhow::{anyhow, bail, Context};
2use axum::{
3 extract::{self, Query},
4 routing::{get, post},
5 Extension, Json, Router,
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use std::{str::FromStr, sync::Arc, time::Duration};
13use stripe::{
14 BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
15 CreateBillingPortalSessionFlowDataAfterCompletion,
16 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
17 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
18 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
19};
20use util::ResultExt;
21
22use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
23use crate::rpc::{ResultExt as _, Server};
24use crate::{
25 db::{
26 billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
27 CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
28 UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
29 UpdateBillingSubscriptionParams,
30 },
31 stripe_billing::StripeBilling,
32};
33use crate::{
34 db::{billing_subscription::StripeSubscriptionStatus, UserId},
35 llm::db::LlmDatabase,
36};
37use crate::{AppState, Cents, Error, Result};
38
39pub fn router() -> Router {
40 Router::new()
41 .route(
42 "/billing/preferences",
43 get(get_billing_preferences).put(update_billing_preferences),
44 )
45 .route(
46 "/billing/subscriptions",
47 get(list_billing_subscriptions).post(create_billing_subscription),
48 )
49 .route(
50 "/billing/subscriptions/manage",
51 post(manage_billing_subscription),
52 )
53 .route("/billing/monthly_spend", get(get_monthly_spend))
54}
55
56#[derive(Debug, Deserialize)]
57struct GetBillingPreferencesParams {
58 github_user_id: i32,
59}
60
61#[derive(Debug, Serialize)]
62struct BillingPreferencesResponse {
63 max_monthly_llm_usage_spending_in_cents: i32,
64}
65
66async fn get_billing_preferences(
67 Extension(app): Extension<Arc<AppState>>,
68 Query(params): Query<GetBillingPreferencesParams>,
69) -> Result<Json<BillingPreferencesResponse>> {
70 let user = app
71 .db
72 .get_user_by_github_user_id(params.github_user_id)
73 .await?
74 .ok_or_else(|| anyhow!("user not found"))?;
75
76 let preferences = app.db.get_billing_preferences(user.id).await?;
77
78 Ok(Json(BillingPreferencesResponse {
79 max_monthly_llm_usage_spending_in_cents: preferences
80 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
81 preferences.max_monthly_llm_usage_spending_in_cents
82 }),
83 }))
84}
85
86#[derive(Debug, Deserialize)]
87struct UpdateBillingPreferencesBody {
88 github_user_id: i32,
89 max_monthly_llm_usage_spending_in_cents: i32,
90}
91
92async fn update_billing_preferences(
93 Extension(app): Extension<Arc<AppState>>,
94 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
95 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
96) -> Result<Json<BillingPreferencesResponse>> {
97 let user = app
98 .db
99 .get_user_by_github_user_id(body.github_user_id)
100 .await?
101 .ok_or_else(|| anyhow!("user not found"))?;
102
103 let billing_preferences =
104 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
105 app.db
106 .update_billing_preferences(
107 user.id,
108 &UpdateBillingPreferencesParams {
109 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
110 body.max_monthly_llm_usage_spending_in_cents,
111 ),
112 },
113 )
114 .await?
115 } else {
116 app.db
117 .create_billing_preferences(
118 user.id,
119 &crate::db::CreateBillingPreferencesParams {
120 max_monthly_llm_usage_spending_in_cents: body
121 .max_monthly_llm_usage_spending_in_cents,
122 },
123 )
124 .await?
125 };
126
127 rpc_server.refresh_llm_tokens_for_user(user.id).await;
128
129 Ok(Json(BillingPreferencesResponse {
130 max_monthly_llm_usage_spending_in_cents: billing_preferences
131 .max_monthly_llm_usage_spending_in_cents,
132 }))
133}
134
135#[derive(Debug, Deserialize)]
136struct ListBillingSubscriptionsParams {
137 github_user_id: i32,
138}
139
140#[derive(Debug, Serialize)]
141struct BillingSubscriptionJson {
142 id: BillingSubscriptionId,
143 name: String,
144 status: StripeSubscriptionStatus,
145 cancel_at: Option<String>,
146 /// Whether this subscription can be canceled.
147 is_cancelable: bool,
148}
149
150#[derive(Debug, Serialize)]
151struct ListBillingSubscriptionsResponse {
152 subscriptions: Vec<BillingSubscriptionJson>,
153}
154
155async fn list_billing_subscriptions(
156 Extension(app): Extension<Arc<AppState>>,
157 Query(params): Query<ListBillingSubscriptionsParams>,
158) -> Result<Json<ListBillingSubscriptionsResponse>> {
159 let user = app
160 .db
161 .get_user_by_github_user_id(params.github_user_id)
162 .await?
163 .ok_or_else(|| anyhow!("user not found"))?;
164
165 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
166
167 Ok(Json(ListBillingSubscriptionsResponse {
168 subscriptions: subscriptions
169 .into_iter()
170 .map(|subscription| BillingSubscriptionJson {
171 id: subscription.id,
172 name: "Zed LLM Usage".to_string(),
173 status: subscription.stripe_subscription_status,
174 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
175 cancel_at
176 .and_utc()
177 .to_rfc3339_opts(SecondsFormat::Millis, true)
178 }),
179 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
180 && subscription.stripe_cancel_at.is_none(),
181 })
182 .collect(),
183 }))
184}
185
186#[derive(Debug, Deserialize)]
187struct CreateBillingSubscriptionBody {
188 github_user_id: i32,
189}
190
191#[derive(Debug, Serialize)]
192struct CreateBillingSubscriptionResponse {
193 checkout_session_url: String,
194}
195
196/// Initiates a Stripe Checkout session for creating a billing subscription.
197async fn create_billing_subscription(
198 Extension(app): Extension<Arc<AppState>>,
199 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
200) -> Result<Json<CreateBillingSubscriptionResponse>> {
201 let user = app
202 .db
203 .get_user_by_github_user_id(body.github_user_id)
204 .await?
205 .ok_or_else(|| anyhow!("user not found"))?;
206
207 let Some(stripe_client) = app.stripe_client.clone() else {
208 log::error!("failed to retrieve Stripe client");
209 Err(Error::http(
210 StatusCode::NOT_IMPLEMENTED,
211 "not supported".into(),
212 ))?
213 };
214 let Some(stripe_billing) = app.stripe_billing.clone() else {
215 log::error!("failed to retrieve Stripe billing object");
216 Err(Error::http(
217 StatusCode::NOT_IMPLEMENTED,
218 "not supported".into(),
219 ))?
220 };
221 let Some(llm_db) = app.llm_db.clone() else {
222 log::error!("failed to retrieve LLM database");
223 Err(Error::http(
224 StatusCode::NOT_IMPLEMENTED,
225 "not supported".into(),
226 ))?
227 };
228
229 if app.db.has_active_billing_subscription(user.id).await? {
230 return Err(Error::http(
231 StatusCode::CONFLICT,
232 "user already has an active subscription".into(),
233 ));
234 }
235
236 let customer_id =
237 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
238 CustomerId::from_str(&existing_customer.stripe_customer_id)
239 .context("failed to parse customer ID")?
240 } else {
241 let customer = Customer::create(
242 &stripe_client,
243 CreateCustomer {
244 email: user.email_address.as_deref(),
245 ..Default::default()
246 },
247 )
248 .await?;
249
250 customer.id
251 };
252
253 let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
254 let stripe_model = stripe_billing.register_model(default_model).await?;
255 let success_url = format!(
256 "{}/account?checkout_complete=1",
257 app.config.zed_dot_dev_url()
258 );
259 let checkout_session_url = stripe_billing
260 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
261 .await?;
262 Ok(Json(CreateBillingSubscriptionResponse {
263 checkout_session_url,
264 }))
265}
266
267#[derive(Debug, PartialEq, Deserialize)]
268#[serde(rename_all = "snake_case")]
269enum ManageSubscriptionIntent {
270 /// The user intends to cancel their subscription.
271 Cancel,
272 /// The user intends to stop the cancellation of their subscription.
273 StopCancellation,
274}
275
276#[derive(Debug, Deserialize)]
277struct ManageBillingSubscriptionBody {
278 github_user_id: i32,
279 intent: ManageSubscriptionIntent,
280 /// The ID of the subscription to manage.
281 subscription_id: BillingSubscriptionId,
282}
283
284#[derive(Debug, Serialize)]
285struct ManageBillingSubscriptionResponse {
286 billing_portal_session_url: Option<String>,
287}
288
289/// Initiates a Stripe customer portal session for managing a billing subscription.
290async fn manage_billing_subscription(
291 Extension(app): Extension<Arc<AppState>>,
292 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
293) -> Result<Json<ManageBillingSubscriptionResponse>> {
294 let user = app
295 .db
296 .get_user_by_github_user_id(body.github_user_id)
297 .await?
298 .ok_or_else(|| anyhow!("user not found"))?;
299
300 let Some(stripe_client) = app.stripe_client.clone() else {
301 log::error!("failed to retrieve Stripe client");
302 Err(Error::http(
303 StatusCode::NOT_IMPLEMENTED,
304 "not supported".into(),
305 ))?
306 };
307
308 let customer = app
309 .db
310 .get_billing_customer_by_user_id(user.id)
311 .await?
312 .ok_or_else(|| anyhow!("billing customer not found"))?;
313 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
314 .context("failed to parse customer ID")?;
315
316 let subscription = app
317 .db
318 .get_billing_subscription_by_id(body.subscription_id)
319 .await?
320 .ok_or_else(|| anyhow!("subscription not found"))?;
321
322 if body.intent == ManageSubscriptionIntent::StopCancellation {
323 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
324 .context("failed to parse subscription ID")?;
325
326 let updated_stripe_subscription = Subscription::update(
327 &stripe_client,
328 &subscription_id,
329 stripe::UpdateSubscription {
330 cancel_at_period_end: Some(false),
331 ..Default::default()
332 },
333 )
334 .await?;
335
336 app.db
337 .update_billing_subscription(
338 subscription.id,
339 &UpdateBillingSubscriptionParams {
340 stripe_cancel_at: ActiveValue::set(
341 updated_stripe_subscription
342 .cancel_at
343 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
344 .map(|time| time.naive_utc()),
345 ),
346 ..Default::default()
347 },
348 )
349 .await?;
350
351 return Ok(Json(ManageBillingSubscriptionResponse {
352 billing_portal_session_url: None,
353 }));
354 }
355
356 let flow = match body.intent {
357 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
358 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
359 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
360 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
361 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
362 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
363 }),
364 ..Default::default()
365 }),
366 subscription_cancel: Some(
367 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
368 subscription: subscription.stripe_subscription_id,
369 retention: None,
370 },
371 ),
372 ..Default::default()
373 },
374 ManageSubscriptionIntent::StopCancellation => unreachable!(),
375 };
376
377 let mut params = CreateBillingPortalSession::new(customer_id);
378 params.flow_data = Some(flow);
379 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
380 params.return_url = Some(&return_url);
381
382 let session = BillingPortalSession::create(&stripe_client, params).await?;
383
384 Ok(Json(ManageBillingSubscriptionResponse {
385 billing_portal_session_url: Some(session.url),
386 }))
387}
388
/// How long we wait between consecutive polls of the Stripe events API.
///
/// The value trades off two concerns:
/// 1. It should be short enough that changes made in Stripe show up quickly.
/// 2. It should be long enough that polling doesn't eat into our rate limits.
///
/// For comparison, the Sequin folks poll every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);

/// How many events we ask Stripe for per page.
///
/// This is the maximum Stripe allows, which keeps the number of requests low:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages consisting entirely of already-processed events we will
/// page through before we stop retrieving events.
///
/// This keeps us from over-fetching events from the Stripe events API that we
/// have already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
415
416/// Polls the Stripe events API periodically to reconcile the records in our
417/// database with the data in Stripe.
418pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
419 let Some(stripe_client) = app.stripe_client.clone() else {
420 log::warn!("failed to retrieve Stripe client");
421 return;
422 };
423
424 let executor = app.executor.clone();
425 executor.spawn_detached({
426 let executor = executor.clone();
427 async move {
428 loop {
429 poll_stripe_events(&app, &rpc_server, &stripe_client)
430 .await
431 .log_err();
432
433 executor.sleep(POLL_EVENTS_INTERVAL).await;
434 }
435 }
436 });
437}
438
439async fn poll_stripe_events(
440 app: &Arc<AppState>,
441 rpc_server: &Arc<Server>,
442 stripe_client: &stripe::Client,
443) -> anyhow::Result<()> {
444 fn event_type_to_string(event_type: EventType) -> String {
445 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
446 // so we need to unquote it.
447 event_type.to_string().trim_matches('"').to_string()
448 }
449
450 let event_types = [
451 EventType::CustomerCreated,
452 EventType::CustomerUpdated,
453 EventType::CustomerSubscriptionCreated,
454 EventType::CustomerSubscriptionUpdated,
455 EventType::CustomerSubscriptionPaused,
456 EventType::CustomerSubscriptionResumed,
457 EventType::CustomerSubscriptionDeleted,
458 ]
459 .into_iter()
460 .map(event_type_to_string)
461 .collect::<Vec<_>>();
462
463 let mut pages_of_already_processed_events = 0;
464 let mut unprocessed_events = Vec::new();
465
466 log::info!(
467 "Stripe events: starting retrieval for {}",
468 event_types.join(", ")
469 );
470 let mut params = ListEvents::new();
471 params.types = Some(event_types.clone());
472 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
473
474 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
475 .await?
476 .paginate(params);
477
478 loop {
479 let processed_event_ids = {
480 let event_ids = event_pages
481 .page
482 .data
483 .iter()
484 .map(|event| event.id.as_str())
485 .collect::<Vec<_>>();
486 app.db
487 .get_processed_stripe_events_by_event_ids(&event_ids)
488 .await?
489 .into_iter()
490 .map(|event| event.stripe_event_id)
491 .collect::<Vec<_>>()
492 };
493
494 let mut processed_events_in_page = 0;
495 let events_in_page = event_pages.page.data.len();
496 for event in &event_pages.page.data {
497 if processed_event_ids.contains(&event.id.to_string()) {
498 processed_events_in_page += 1;
499 log::debug!("Stripe events: already processed '{}', skipping", event.id);
500 } else {
501 unprocessed_events.push(event.clone());
502 }
503 }
504
505 if processed_events_in_page == events_in_page {
506 pages_of_already_processed_events += 1;
507 }
508
509 if event_pages.page.has_more {
510 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
511 {
512 log::info!("Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events");
513 break;
514 } else {
515 log::info!("Stripe events: retrieving next page");
516 event_pages = event_pages.next(&stripe_client).await?;
517 }
518 } else {
519 break;
520 }
521 }
522
523 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
524
525 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
526 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
527
528 for event in unprocessed_events {
529 let event_id = event.id.clone();
530 let processed_event_params = CreateProcessedStripeEventParams {
531 stripe_event_id: event.id.to_string(),
532 stripe_event_type: event_type_to_string(event.type_),
533 stripe_event_created_timestamp: event.created,
534 };
535
536 // If the event has happened too far in the past, we don't want to
537 // process it and risk overwriting other more-recent updates.
538 //
539 // 1 day was chosen arbitrarily. This could be made longer or shorter.
540 let one_day = Duration::from_secs(24 * 60 * 60);
541 let a_day_ago = Utc::now() - one_day;
542 if a_day_ago.timestamp() > event.created {
543 log::info!(
544 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
545 event_id
546 );
547 app.db
548 .create_processed_stripe_event(&processed_event_params)
549 .await?;
550
551 return Ok(());
552 }
553
554 let process_result = match event.type_ {
555 EventType::CustomerCreated | EventType::CustomerUpdated => {
556 handle_customer_event(app, stripe_client, event).await
557 }
558 EventType::CustomerSubscriptionCreated
559 | EventType::CustomerSubscriptionUpdated
560 | EventType::CustomerSubscriptionPaused
561 | EventType::CustomerSubscriptionResumed
562 | EventType::CustomerSubscriptionDeleted => {
563 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
564 }
565 _ => Ok(()),
566 };
567
568 if let Some(()) = process_result
569 .with_context(|| format!("failed to process event {event_id} successfully"))
570 .log_err()
571 {
572 app.db
573 .create_processed_stripe_event(&processed_event_params)
574 .await?;
575 }
576 }
577
578 Ok(())
579}
580
581async fn handle_customer_event(
582 app: &Arc<AppState>,
583 _stripe_client: &stripe::Client,
584 event: stripe::Event,
585) -> anyhow::Result<()> {
586 let EventObject::Customer(customer) = event.data.object else {
587 bail!("unexpected event payload for {}", event.id);
588 };
589
590 log::info!("handling Stripe {} event: {}", event.type_, event.id);
591
592 let Some(email) = customer.email else {
593 log::info!("Stripe customer has no email: skipping");
594 return Ok(());
595 };
596
597 let Some(user) = app.db.get_user_by_email(&email).await? else {
598 log::info!("no user found for email: skipping");
599 return Ok(());
600 };
601
602 if let Some(existing_customer) = app
603 .db
604 .get_billing_customer_by_stripe_customer_id(&customer.id)
605 .await?
606 {
607 app.db
608 .update_billing_customer(
609 existing_customer.id,
610 &UpdateBillingCustomerParams {
611 // For now we just leave the information as-is, as it is not
612 // likely to change.
613 ..Default::default()
614 },
615 )
616 .await?;
617 } else {
618 app.db
619 .create_billing_customer(&CreateBillingCustomerParams {
620 user_id: user.id,
621 stripe_customer_id: customer.id.to_string(),
622 })
623 .await?;
624 }
625
626 Ok(())
627}
628
629async fn handle_customer_subscription_event(
630 app: &Arc<AppState>,
631 rpc_server: &Arc<Server>,
632 stripe_client: &stripe::Client,
633 event: stripe::Event,
634) -> anyhow::Result<()> {
635 let EventObject::Subscription(subscription) = event.data.object else {
636 bail!("unexpected event payload for {}", event.id);
637 };
638
639 log::info!("handling Stripe {} event: {}", event.type_, event.id);
640
641 let billing_customer =
642 find_or_create_billing_customer(app, stripe_client, subscription.customer)
643 .await?
644 .ok_or_else(|| anyhow!("billing customer not found"))?;
645
646 if let Some(existing_subscription) = app
647 .db
648 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
649 .await?
650 {
651 app.db
652 .update_billing_subscription(
653 existing_subscription.id,
654 &UpdateBillingSubscriptionParams {
655 billing_customer_id: ActiveValue::set(billing_customer.id),
656 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
657 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
658 stripe_cancel_at: ActiveValue::set(
659 subscription
660 .cancel_at
661 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
662 .map(|time| time.naive_utc()),
663 ),
664 },
665 )
666 .await?;
667 } else {
668 // If the user already has an active billing subscription, ignore the
669 // event and return an `Ok` to signal that it was processed
670 // successfully.
671 //
672 // There is the possibility that this could cause us to not create a
673 // subscription in the following scenario:
674 //
675 // 1. User has an active subscription A
676 // 2. User cancels subscription A
677 // 3. User creates a new subscription B
678 // 4. We process the new subscription B before the cancellation of subscription A
679 // 5. User ends up with no subscriptions
680 //
681 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
682 if app
683 .db
684 .has_active_billing_subscription(billing_customer.user_id)
685 .await?
686 {
687 log::info!(
688 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
689 user_id = billing_customer.user_id,
690 subscription_id = subscription.id
691 );
692 return Ok(());
693 }
694
695 app.db
696 .create_billing_subscription(&CreateBillingSubscriptionParams {
697 billing_customer_id: billing_customer.id,
698 stripe_subscription_id: subscription.id.to_string(),
699 stripe_subscription_status: subscription.status.into(),
700 })
701 .await?;
702 }
703
704 // When the user's subscription changes, we want to refresh their LLM tokens
705 // to either grant/revoke access.
706 rpc_server
707 .refresh_llm_tokens_for_user(billing_customer.user_id)
708 .await;
709
710 Ok(())
711}
712
713#[derive(Debug, Deserialize)]
714struct GetMonthlySpendParams {
715 github_user_id: i32,
716}
717
718#[derive(Debug, Serialize)]
719struct GetMonthlySpendResponse {
720 monthly_free_tier_spend_in_cents: u32,
721 monthly_free_tier_allowance_in_cents: u32,
722 monthly_spend_in_cents: u32,
723}
724
725async fn get_monthly_spend(
726 Extension(app): Extension<Arc<AppState>>,
727 Query(params): Query<GetMonthlySpendParams>,
728) -> Result<Json<GetMonthlySpendResponse>> {
729 let user = app
730 .db
731 .get_user_by_github_user_id(params.github_user_id)
732 .await?
733 .ok_or_else(|| anyhow!("user not found"))?;
734
735 let Some(llm_db) = app.llm_db.clone() else {
736 return Err(Error::http(
737 StatusCode::NOT_IMPLEMENTED,
738 "LLM database not available".into(),
739 ));
740 };
741
742 let free_tier = user
743 .custom_llm_monthly_allowance_in_cents
744 .map(|allowance| Cents(allowance as u32))
745 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
746
747 let spending_for_month = llm_db
748 .get_user_spending_for_month(user.id, Utc::now())
749 .await?;
750
751 let free_tier_spend = Cents::min(spending_for_month, free_tier);
752 let monthly_spend = spending_for_month.saturating_sub(free_tier);
753
754 Ok(Json(GetMonthlySpendResponse {
755 monthly_free_tier_spend_in_cents: free_tier_spend.0,
756 monthly_free_tier_allowance_in_cents: free_tier.0,
757 monthly_spend_in_cents: monthly_spend.0,
758 }))
759}
760
761impl From<SubscriptionStatus> for StripeSubscriptionStatus {
762 fn from(value: SubscriptionStatus) -> Self {
763 match value {
764 SubscriptionStatus::Incomplete => Self::Incomplete,
765 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
766 SubscriptionStatus::Trialing => Self::Trialing,
767 SubscriptionStatus::Active => Self::Active,
768 SubscriptionStatus::PastDue => Self::PastDue,
769 SubscriptionStatus::Canceled => Self::Canceled,
770 SubscriptionStatus::Unpaid => Self::Unpaid,
771 SubscriptionStatus::Paused => Self::Paused,
772 }
773 }
774}
775
776/// Finds or creates a billing customer using the provided customer.
777async fn find_or_create_billing_customer(
778 app: &Arc<AppState>,
779 stripe_client: &stripe::Client,
780 customer_or_id: Expandable<Customer>,
781) -> anyhow::Result<Option<billing_customer::Model>> {
782 let customer_id = match &customer_or_id {
783 Expandable::Id(id) => id,
784 Expandable::Object(customer) => customer.id.as_ref(),
785 };
786
787 // If we already have a billing customer record associated with the Stripe customer,
788 // there's nothing more we need to do.
789 if let Some(billing_customer) = app
790 .db
791 .get_billing_customer_by_stripe_customer_id(customer_id)
792 .await?
793 {
794 return Ok(Some(billing_customer));
795 }
796
797 // If all we have is a customer ID, resolve it to a full customer record by
798 // hitting the Stripe API.
799 let customer = match customer_or_id {
800 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
801 Expandable::Object(customer) => *customer,
802 };
803
804 let Some(email) = customer.email else {
805 return Ok(None);
806 };
807
808 let Some(user) = app.db.get_user_by_email(&email).await? else {
809 return Ok(None);
810 };
811
812 let billing_customer = app
813 .db
814 .create_billing_customer(&CreateBillingCustomerParams {
815 user_id: user.id,
816 stripe_customer_id: customer.id.to_string(),
817 })
818 .await?;
819
820 Ok(Some(billing_customer))
821}
822
823const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
824
825pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
826 let Some(stripe_billing) = app.stripe_billing.clone() else {
827 log::warn!("failed to retrieve Stripe billing object");
828 return;
829 };
830 let Some(llm_db) = app.llm_db.clone() else {
831 log::warn!("failed to retrieve LLM database");
832 return;
833 };
834
835 let executor = app.executor.clone();
836 executor.spawn_detached({
837 let executor = executor.clone();
838 async move {
839 loop {
840 sync_with_stripe(&app, &llm_db, &stripe_billing)
841 .await
842 .context("failed to sync LLM usage to Stripe")
843 .trace_err();
844 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
845 }
846 }
847 });
848}
849
850async fn sync_with_stripe(
851 app: &Arc<AppState>,
852 llm_db: &Arc<LlmDatabase>,
853 stripe_billing: &Arc<StripeBilling>,
854) -> anyhow::Result<()> {
855 let events = llm_db.get_billing_events().await?;
856 let user_ids = events
857 .iter()
858 .map(|(event, _)| event.user_id)
859 .collect::<HashSet<UserId>>();
860 let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
861
862 for (event, model) in events {
863 let Some((stripe_db_customer, stripe_db_subscription)) =
864 stripe_subscriptions.get(&event.user_id)
865 else {
866 tracing::warn!(
867 user_id = event.user_id.0,
868 "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
869 );
870 continue;
871 };
872 let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
873 .stripe_subscription_id
874 .parse()
875 .context("failed to parse stripe subscription id from db")?;
876 let stripe_customer_id: stripe::CustomerId = stripe_db_customer
877 .stripe_customer_id
878 .parse()
879 .context("failed to parse stripe customer id from db")?;
880
881 let stripe_model = stripe_billing.register_model(&model).await?;
882 stripe_billing
883 .subscribe_to_model(&stripe_subscription_id, &stripe_model)
884 .await?;
885 stripe_billing
886 .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
887 .await?;
888 llm_db.consume_billing_event(event.id).await?;
889 }
890
891 Ok(())
892}