1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
19 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
20};
21use util::ResultExt;
22
23use crate::api::events::SnowflakeRow;
24use crate::db::billing_subscription::{
25 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
26};
27use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
28use crate::rpc::{ResultExt as _, Server};
29use crate::{AppState, Cents, Error, Result};
30use crate::{db::UserId, llm::db::LlmDatabase};
31use crate::{
32 db::{
33 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
34 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
35 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
36 },
37 stripe_billing::StripeBilling,
38};
39
40pub fn router() -> Router {
41 Router::new()
42 .route(
43 "/billing/preferences",
44 get(get_billing_preferences).put(update_billing_preferences),
45 )
46 .route(
47 "/billing/subscriptions",
48 get(list_billing_subscriptions).post(create_billing_subscription),
49 )
50 .route(
51 "/billing/subscriptions/manage",
52 post(manage_billing_subscription),
53 )
54 .route("/billing/monthly_spend", get(get_monthly_spend))
55}
56
/// Query-string parameters accepted by `get_billing_preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID used to look up the user whose preferences are requested.
    github_user_id: i32,
}
61
/// JSON response returned by both `get_billing_preferences` and
/// `update_billing_preferences`.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// The user's maximum monthly LLM usage spend, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
}
66
67async fn get_billing_preferences(
68 Extension(app): Extension<Arc<AppState>>,
69 Query(params): Query<GetBillingPreferencesParams>,
70) -> Result<Json<BillingPreferencesResponse>> {
71 let user = app
72 .db
73 .get_user_by_github_user_id(params.github_user_id)
74 .await?
75 .ok_or_else(|| anyhow!("user not found"))?;
76
77 let preferences = app.db.get_billing_preferences(user.id).await?;
78
79 Ok(Json(BillingPreferencesResponse {
80 max_monthly_llm_usage_spending_in_cents: preferences
81 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
82 preferences.max_monthly_llm_usage_spending_in_cents
83 }),
84 }))
85}
86
/// JSON request body accepted by `update_billing_preferences`.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID used to look up the user whose preferences are updated.
    github_user_id: i32,
    /// Requested maximum monthly LLM usage spend, in cents. The handler
    /// clamps negative values to zero before storing.
    max_monthly_llm_usage_spending_in_cents: i32,
}
92
93async fn update_billing_preferences(
94 Extension(app): Extension<Arc<AppState>>,
95 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
96 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
97) -> Result<Json<BillingPreferencesResponse>> {
98 let user = app
99 .db
100 .get_user_by_github_user_id(body.github_user_id)
101 .await?
102 .ok_or_else(|| anyhow!("user not found"))?;
103
104 let max_monthly_llm_usage_spending_in_cents =
105 body.max_monthly_llm_usage_spending_in_cents.max(0);
106
107 let billing_preferences =
108 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
109 app.db
110 .update_billing_preferences(
111 user.id,
112 &UpdateBillingPreferencesParams {
113 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
114 max_monthly_llm_usage_spending_in_cents,
115 ),
116 },
117 )
118 .await?
119 } else {
120 app.db
121 .create_billing_preferences(
122 user.id,
123 &crate::db::CreateBillingPreferencesParams {
124 max_monthly_llm_usage_spending_in_cents,
125 },
126 )
127 .await?
128 };
129
130 SnowflakeRow::new(
131 "Spend Limit Updated",
132 Some(user.metrics_id),
133 user.admin,
134 None,
135 json!({
136 "user_id": user.id,
137 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
138 }),
139 )
140 .write(&app.kinesis_client, &app.config.kinesis_stream)
141 .await
142 .log_err();
143
144 rpc_server.refresh_llm_tokens_for_user(user.id).await;
145
146 Ok(Json(BillingPreferencesResponse {
147 max_monthly_llm_usage_spending_in_cents: billing_preferences
148 .max_monthly_llm_usage_spending_in_cents,
149 }))
150}
151
/// Query-string parameters accepted by `list_billing_subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}
156
/// JSON representation of a single billing subscription, as returned by
/// `list_billing_subscriptions`.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Our database ID for the subscription.
    id: BillingSubscriptionId,
    /// Display name derived from the subscription kind
    /// ("Zed Pro" or "Zed LLM Usage").
    name: String,
    /// The subscription status as recorded from Stripe.
    status: StripeSubscriptionStatus,
    /// When the subscription is scheduled to be canceled, as an RFC 3339
    /// timestamp with millisecond precision, if a cancellation is scheduled.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}
166
/// JSON response returned by `list_billing_subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The user's billing subscriptions, in the order the database returned them.
    subscriptions: Vec<BillingSubscriptionJson>,
}
171
172async fn list_billing_subscriptions(
173 Extension(app): Extension<Arc<AppState>>,
174 Query(params): Query<ListBillingSubscriptionsParams>,
175) -> Result<Json<ListBillingSubscriptionsResponse>> {
176 let user = app
177 .db
178 .get_user_by_github_user_id(params.github_user_id)
179 .await?
180 .ok_or_else(|| anyhow!("user not found"))?;
181
182 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
183
184 Ok(Json(ListBillingSubscriptionsResponse {
185 subscriptions: subscriptions
186 .into_iter()
187 .map(|subscription| BillingSubscriptionJson {
188 id: subscription.id,
189 name: match subscription.kind {
190 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
191 None => "Zed LLM Usage".to_string(),
192 },
193 status: subscription.stripe_subscription_status,
194 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
195 cancel_at
196 .and_utc()
197 .to_rfc3339_opts(SecondsFormat::Millis, true)
198 }),
199 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
200 && subscription.stripe_cancel_at.is_none(),
201 })
202 .collect(),
203 }))
204}
205
/// The product a checkout is being created for.
///
/// Deserialized from its snake_case name (e.g. `"zed_pro"`).
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// The Zed Pro subscription.
    ZedPro,
}
211
/// JSON request body accepted by `create_billing_subscription`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the user starting the checkout.
    github_user_id: i32,
    /// The product to check out. When absent, the handler creates a checkout
    /// for the default language model instead.
    product: Option<ProductCode>,
}
217
/// JSON response returned by `create_billing_subscription`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the checkout session returned by the Stripe billing helper.
    checkout_session_url: String,
}
222
223/// Initiates a Stripe Checkout session for creating a billing subscription.
224async fn create_billing_subscription(
225 Extension(app): Extension<Arc<AppState>>,
226 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
227) -> Result<Json<CreateBillingSubscriptionResponse>> {
228 let user = app
229 .db
230 .get_user_by_github_user_id(body.github_user_id)
231 .await?
232 .ok_or_else(|| anyhow!("user not found"))?;
233
234 let Some(stripe_client) = app.stripe_client.clone() else {
235 log::error!("failed to retrieve Stripe client");
236 Err(Error::http(
237 StatusCode::NOT_IMPLEMENTED,
238 "not supported".into(),
239 ))?
240 };
241 let Some(stripe_billing) = app.stripe_billing.clone() else {
242 log::error!("failed to retrieve Stripe billing object");
243 Err(Error::http(
244 StatusCode::NOT_IMPLEMENTED,
245 "not supported".into(),
246 ))?
247 };
248 let Some(llm_db) = app.llm_db.clone() else {
249 log::error!("failed to retrieve LLM database");
250 Err(Error::http(
251 StatusCode::NOT_IMPLEMENTED,
252 "not supported".into(),
253 ))?
254 };
255
256 if app.db.has_active_billing_subscription(user.id).await? {
257 return Err(Error::http(
258 StatusCode::CONFLICT,
259 "user already has an active subscription".into(),
260 ));
261 }
262
263 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
264 if let Some(existing_billing_customer) = &existing_billing_customer {
265 if existing_billing_customer.has_overdue_invoices {
266 return Err(Error::http(
267 StatusCode::PAYMENT_REQUIRED,
268 "user has overdue invoices".into(),
269 ));
270 }
271 }
272
273 let customer_id = if let Some(existing_customer) = existing_billing_customer {
274 CustomerId::from_str(&existing_customer.stripe_customer_id)
275 .context("failed to parse customer ID")?
276 } else {
277 let customer = Customer::create(
278 &stripe_client,
279 CreateCustomer {
280 email: user.email_address.as_deref(),
281 ..Default::default()
282 },
283 )
284 .await?;
285
286 customer.id
287 };
288
289 let checkout_session_url = match body.product {
290 Some(ProductCode::ZedPro) => {
291 let success_url = format!(
292 "{}/account?checkout_complete=1",
293 app.config.zed_dot_dev_url()
294 );
295 stripe_billing
296 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
297 .await?
298 }
299 None => {
300 let default_model =
301 llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
302 let stripe_model = stripe_billing.register_model(default_model).await?;
303 let success_url = format!(
304 "{}/account?checkout_complete=1",
305 app.config.zed_dot_dev_url()
306 );
307 stripe_billing
308 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
309 .await?
310 }
311 };
312
313 Ok(Json(CreateBillingSubscriptionResponse {
314 checkout_session_url,
315 }))
316}
317
/// What the user wants to do with an existing subscription.
///
/// Deserialized from its snake_case name (e.g. `"stop_cancellation"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to cancel their subscription.
    ///
    /// This opens the Stripe billing portal in its subscription-cancel flow.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    ///
    /// This is handled directly through the Stripe API, without a billing portal session.
    StopCancellation,
}
330
/// JSON request body accepted by `manage_billing_subscription`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user managing the subscription.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
}
338
/// JSON response returned by `manage_billing_subscription`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session, or `None` when the request
    /// was fulfilled without opening one (the `stop_cancellation` intent).
    billing_portal_session_url: Option<String>,
}
343
344/// Initiates a Stripe customer portal session for managing a billing subscription.
345async fn manage_billing_subscription(
346 Extension(app): Extension<Arc<AppState>>,
347 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
348) -> Result<Json<ManageBillingSubscriptionResponse>> {
349 let user = app
350 .db
351 .get_user_by_github_user_id(body.github_user_id)
352 .await?
353 .ok_or_else(|| anyhow!("user not found"))?;
354
355 let Some(stripe_client) = app.stripe_client.clone() else {
356 log::error!("failed to retrieve Stripe client");
357 Err(Error::http(
358 StatusCode::NOT_IMPLEMENTED,
359 "not supported".into(),
360 ))?
361 };
362
363 let customer = app
364 .db
365 .get_billing_customer_by_user_id(user.id)
366 .await?
367 .ok_or_else(|| anyhow!("billing customer not found"))?;
368 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
369 .context("failed to parse customer ID")?;
370
371 let subscription = app
372 .db
373 .get_billing_subscription_by_id(body.subscription_id)
374 .await?
375 .ok_or_else(|| anyhow!("subscription not found"))?;
376
377 if body.intent == ManageSubscriptionIntent::StopCancellation {
378 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
379 .context("failed to parse subscription ID")?;
380
381 let updated_stripe_subscription = Subscription::update(
382 &stripe_client,
383 &subscription_id,
384 stripe::UpdateSubscription {
385 cancel_at_period_end: Some(false),
386 ..Default::default()
387 },
388 )
389 .await?;
390
391 app.db
392 .update_billing_subscription(
393 subscription.id,
394 &UpdateBillingSubscriptionParams {
395 stripe_cancel_at: ActiveValue::set(
396 updated_stripe_subscription
397 .cancel_at
398 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
399 .map(|time| time.naive_utc()),
400 ),
401 ..Default::default()
402 },
403 )
404 .await?;
405
406 return Ok(Json(ManageBillingSubscriptionResponse {
407 billing_portal_session_url: None,
408 }));
409 }
410
411 let flow = match body.intent {
412 ManageSubscriptionIntent::ManageSubscription => None,
413 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
414 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
415 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
416 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
417 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
418 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
419 }),
420 ..Default::default()
421 }),
422 subscription_cancel: Some(
423 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
424 subscription: subscription.stripe_subscription_id,
425 retention: None,
426 },
427 ),
428 ..Default::default()
429 }),
430 ManageSubscriptionIntent::StopCancellation => unreachable!(),
431 };
432
433 let mut params = CreateBillingPortalSession::new(customer_id);
434 params.flow_data = flow;
435 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
436 params.return_url = Some(&return_url);
437
438 let session = BillingPortalSession::create(&stripe_client, params).await?;
439
440 Ok(Json(ManageBillingSubscriptionResponse {
441 billing_portal_session_url: Some(session.url),
442 }))
443}
444
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// We currently wait **5 seconds** after each poll finishes.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
457
/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
/// It is passed as the `limit` parameter when listing events.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
464
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
///
/// Note that the pages are counted cumulatively over a single poll; they do
/// not have to be consecutive.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
471
472/// Polls the Stripe events API periodically to reconcile the records in our
473/// database with the data in Stripe.
474pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
475 let Some(stripe_client) = app.stripe_client.clone() else {
476 log::warn!("failed to retrieve Stripe client");
477 return;
478 };
479
480 let executor = app.executor.clone();
481 executor.spawn_detached({
482 let executor = executor.clone();
483 async move {
484 loop {
485 poll_stripe_events(&app, &rpc_server, &stripe_client)
486 .await
487 .log_err();
488
489 executor.sleep(POLL_EVENTS_INTERVAL).await;
490 }
491 }
492 });
493}
494
495async fn poll_stripe_events(
496 app: &Arc<AppState>,
497 rpc_server: &Arc<Server>,
498 stripe_client: &stripe::Client,
499) -> anyhow::Result<()> {
500 fn event_type_to_string(event_type: EventType) -> String {
501 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
502 // so we need to unquote it.
503 event_type.to_string().trim_matches('"').to_string()
504 }
505
506 let event_types = [
507 EventType::CustomerCreated,
508 EventType::CustomerUpdated,
509 EventType::CustomerSubscriptionCreated,
510 EventType::CustomerSubscriptionUpdated,
511 EventType::CustomerSubscriptionPaused,
512 EventType::CustomerSubscriptionResumed,
513 EventType::CustomerSubscriptionDeleted,
514 ]
515 .into_iter()
516 .map(event_type_to_string)
517 .collect::<Vec<_>>();
518
519 let mut pages_of_already_processed_events = 0;
520 let mut unprocessed_events = Vec::new();
521
522 log::info!(
523 "Stripe events: starting retrieval for {}",
524 event_types.join(", ")
525 );
526 let mut params = ListEvents::new();
527 params.types = Some(event_types.clone());
528 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
529
530 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
531 .await?
532 .paginate(params);
533
534 loop {
535 let processed_event_ids = {
536 let event_ids = event_pages
537 .page
538 .data
539 .iter()
540 .map(|event| event.id.as_str())
541 .collect::<Vec<_>>();
542 app.db
543 .get_processed_stripe_events_by_event_ids(&event_ids)
544 .await?
545 .into_iter()
546 .map(|event| event.stripe_event_id)
547 .collect::<Vec<_>>()
548 };
549
550 let mut processed_events_in_page = 0;
551 let events_in_page = event_pages.page.data.len();
552 for event in &event_pages.page.data {
553 if processed_event_ids.contains(&event.id.to_string()) {
554 processed_events_in_page += 1;
555 log::debug!("Stripe events: already processed '{}', skipping", event.id);
556 } else {
557 unprocessed_events.push(event.clone());
558 }
559 }
560
561 if processed_events_in_page == events_in_page {
562 pages_of_already_processed_events += 1;
563 }
564
565 if event_pages.page.has_more {
566 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
567 {
568 log::info!(
569 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
570 );
571 break;
572 } else {
573 log::info!("Stripe events: retrieving next page");
574 event_pages = event_pages.next(&stripe_client).await?;
575 }
576 } else {
577 break;
578 }
579 }
580
581 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
582
583 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
584 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
585
586 for event in unprocessed_events {
587 let event_id = event.id.clone();
588 let processed_event_params = CreateProcessedStripeEventParams {
589 stripe_event_id: event.id.to_string(),
590 stripe_event_type: event_type_to_string(event.type_),
591 stripe_event_created_timestamp: event.created,
592 };
593
594 // If the event has happened too far in the past, we don't want to
595 // process it and risk overwriting other more-recent updates.
596 //
597 // 1 day was chosen arbitrarily. This could be made longer or shorter.
598 let one_day = Duration::from_secs(24 * 60 * 60);
599 let a_day_ago = Utc::now() - one_day;
600 if a_day_ago.timestamp() > event.created {
601 log::info!(
602 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
603 event_id
604 );
605 app.db
606 .create_processed_stripe_event(&processed_event_params)
607 .await?;
608
609 return Ok(());
610 }
611
612 let process_result = match event.type_ {
613 EventType::CustomerCreated | EventType::CustomerUpdated => {
614 handle_customer_event(app, stripe_client, event).await
615 }
616 EventType::CustomerSubscriptionCreated
617 | EventType::CustomerSubscriptionUpdated
618 | EventType::CustomerSubscriptionPaused
619 | EventType::CustomerSubscriptionResumed
620 | EventType::CustomerSubscriptionDeleted => {
621 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
622 }
623 _ => Ok(()),
624 };
625
626 if let Some(()) = process_result
627 .with_context(|| format!("failed to process event {event_id} successfully"))
628 .log_err()
629 {
630 app.db
631 .create_processed_stripe_event(&processed_event_params)
632 .await?;
633 }
634 }
635
636 Ok(())
637}
638
639async fn handle_customer_event(
640 app: &Arc<AppState>,
641 _stripe_client: &stripe::Client,
642 event: stripe::Event,
643) -> anyhow::Result<()> {
644 let EventObject::Customer(customer) = event.data.object else {
645 bail!("unexpected event payload for {}", event.id);
646 };
647
648 log::info!("handling Stripe {} event: {}", event.type_, event.id);
649
650 let Some(email) = customer.email else {
651 log::info!("Stripe customer has no email: skipping");
652 return Ok(());
653 };
654
655 let Some(user) = app.db.get_user_by_email(&email).await? else {
656 log::info!("no user found for email: skipping");
657 return Ok(());
658 };
659
660 if let Some(existing_customer) = app
661 .db
662 .get_billing_customer_by_stripe_customer_id(&customer.id)
663 .await?
664 {
665 app.db
666 .update_billing_customer(
667 existing_customer.id,
668 &UpdateBillingCustomerParams {
669 // For now we just leave the information as-is, as it is not
670 // likely to change.
671 ..Default::default()
672 },
673 )
674 .await?;
675 } else {
676 app.db
677 .create_billing_customer(&CreateBillingCustomerParams {
678 user_id: user.id,
679 stripe_customer_id: customer.id.to_string(),
680 })
681 .await?;
682 }
683
684 Ok(())
685}
686
687async fn handle_customer_subscription_event(
688 app: &Arc<AppState>,
689 rpc_server: &Arc<Server>,
690 stripe_client: &stripe::Client,
691 event: stripe::Event,
692) -> anyhow::Result<()> {
693 let EventObject::Subscription(subscription) = event.data.object else {
694 bail!("unexpected event payload for {}", event.id);
695 };
696
697 log::info!("handling Stripe {} event: {}", event.type_, event.id);
698
699 let subscription_kind =
700 if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() {
701 let has_zed_pro_price = subscription.items.data.iter().any(|item| {
702 item.price
703 .as_ref()
704 .map_or(false, |price| price.id.as_str() == zed_pro_price_id)
705 });
706
707 if has_zed_pro_price {
708 Some(SubscriptionKind::ZedPro)
709 } else {
710 None
711 }
712 } else {
713 None
714 };
715
716 let billing_customer =
717 find_or_create_billing_customer(app, stripe_client, subscription.customer)
718 .await?
719 .ok_or_else(|| anyhow!("billing customer not found"))?;
720
721 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
722 && subscription
723 .cancellation_details
724 .as_ref()
725 .and_then(|details| details.reason)
726 .map_or(false, |reason| {
727 reason == CancellationDetailsReason::PaymentFailed
728 });
729
730 if was_canceled_due_to_payment_failure {
731 app.db
732 .update_billing_customer(
733 billing_customer.id,
734 &UpdateBillingCustomerParams {
735 has_overdue_invoices: ActiveValue::set(true),
736 ..Default::default()
737 },
738 )
739 .await?;
740 }
741
742 if let Some(existing_subscription) = app
743 .db
744 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
745 .await?
746 {
747 app.db
748 .update_billing_subscription(
749 existing_subscription.id,
750 &UpdateBillingSubscriptionParams {
751 billing_customer_id: ActiveValue::set(billing_customer.id),
752 kind: ActiveValue::set(subscription_kind),
753 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
754 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
755 stripe_cancel_at: ActiveValue::set(
756 subscription
757 .cancel_at
758 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
759 .map(|time| time.naive_utc()),
760 ),
761 stripe_cancellation_reason: ActiveValue::set(
762 subscription
763 .cancellation_details
764 .and_then(|details| details.reason)
765 .map(|reason| reason.into()),
766 ),
767 stripe_current_period_start: ActiveValue::set(Some(
768 subscription.current_period_start,
769 )),
770 stripe_current_period_end: ActiveValue::set(Some(
771 subscription.current_period_end,
772 )),
773 },
774 )
775 .await?;
776 } else {
777 // If the user already has an active billing subscription, ignore the
778 // event and return an `Ok` to signal that it was processed
779 // successfully.
780 //
781 // There is the possibility that this could cause us to not create a
782 // subscription in the following scenario:
783 //
784 // 1. User has an active subscription A
785 // 2. User cancels subscription A
786 // 3. User creates a new subscription B
787 // 4. We process the new subscription B before the cancellation of subscription A
788 // 5. User ends up with no subscriptions
789 //
790 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
791 if app
792 .db
793 .has_active_billing_subscription(billing_customer.user_id)
794 .await?
795 {
796 log::info!(
797 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
798 user_id = billing_customer.user_id,
799 subscription_id = subscription.id
800 );
801 return Ok(());
802 }
803
804 app.db
805 .create_billing_subscription(&CreateBillingSubscriptionParams {
806 billing_customer_id: billing_customer.id,
807 kind: subscription_kind,
808 stripe_subscription_id: subscription.id.to_string(),
809 stripe_subscription_status: subscription.status.into(),
810 stripe_cancellation_reason: subscription
811 .cancellation_details
812 .and_then(|details| details.reason)
813 .map(|reason| reason.into()),
814 stripe_current_period_start: Some(subscription.current_period_start),
815 stripe_current_period_end: Some(subscription.current_period_end),
816 })
817 .await?;
818 }
819
820 // When the user's subscription changes, we want to refresh their LLM tokens
821 // to either grant/revoke access.
822 rpc_server
823 .refresh_llm_tokens_for_user(billing_customer.user_id)
824 .await;
825
826 Ok(())
827}
828
/// Query-string parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID used to look up the user whose spend is requested.
    github_user_id: i32,
}
833
/// JSON response returned by `get_monthly_spend`.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// The part of this month's spending covered by the free tier
    /// (the lesser of the total spending and the allowance), in cents.
    monthly_free_tier_spend_in_cents: u32,
    /// The user's free-tier allowance for the month, in cents.
    monthly_free_tier_allowance_in_cents: u32,
    /// This month's spending in excess of the free-tier allowance, in cents.
    monthly_spend_in_cents: u32,
}
840
841async fn get_monthly_spend(
842 Extension(app): Extension<Arc<AppState>>,
843 Query(params): Query<GetMonthlySpendParams>,
844) -> Result<Json<GetMonthlySpendResponse>> {
845 let user = app
846 .db
847 .get_user_by_github_user_id(params.github_user_id)
848 .await?
849 .ok_or_else(|| anyhow!("user not found"))?;
850
851 let Some(llm_db) = app.llm_db.clone() else {
852 return Err(Error::http(
853 StatusCode::NOT_IMPLEMENTED,
854 "LLM database not available".into(),
855 ));
856 };
857
858 let free_tier = user
859 .custom_llm_monthly_allowance_in_cents
860 .map(|allowance| Cents(allowance as u32))
861 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
862
863 let spending_for_month = llm_db
864 .get_user_spending_for_month(user.id, Utc::now())
865 .await?;
866
867 let free_tier_spend = Cents::min(spending_for_month, free_tier);
868 let monthly_spend = spending_for_month.saturating_sub(free_tier);
869
870 Ok(Json(GetMonthlySpendResponse {
871 monthly_free_tier_spend_in_cents: free_tier_spend.0,
872 monthly_free_tier_allowance_in_cents: free_tier.0,
873 monthly_spend_in_cents: monthly_spend.0,
874 }))
875}
876
877impl From<SubscriptionStatus> for StripeSubscriptionStatus {
878 fn from(value: SubscriptionStatus) -> Self {
879 match value {
880 SubscriptionStatus::Incomplete => Self::Incomplete,
881 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
882 SubscriptionStatus::Trialing => Self::Trialing,
883 SubscriptionStatus::Active => Self::Active,
884 SubscriptionStatus::PastDue => Self::PastDue,
885 SubscriptionStatus::Canceled => Self::Canceled,
886 SubscriptionStatus::Unpaid => Self::Unpaid,
887 SubscriptionStatus::Paused => Self::Paused,
888 }
889 }
890}
891
892impl From<CancellationDetailsReason> for StripeCancellationReason {
893 fn from(value: CancellationDetailsReason) -> Self {
894 match value {
895 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
896 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
897 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
898 }
899 }
900}
901
902/// Finds or creates a billing customer using the provided customer.
903async fn find_or_create_billing_customer(
904 app: &Arc<AppState>,
905 stripe_client: &stripe::Client,
906 customer_or_id: Expandable<Customer>,
907) -> anyhow::Result<Option<billing_customer::Model>> {
908 let customer_id = match &customer_or_id {
909 Expandable::Id(id) => id,
910 Expandable::Object(customer) => customer.id.as_ref(),
911 };
912
913 // If we already have a billing customer record associated with the Stripe customer,
914 // there's nothing more we need to do.
915 if let Some(billing_customer) = app
916 .db
917 .get_billing_customer_by_stripe_customer_id(customer_id)
918 .await?
919 {
920 return Ok(Some(billing_customer));
921 }
922
923 // If all we have is a customer ID, resolve it to a full customer record by
924 // hitting the Stripe API.
925 let customer = match customer_or_id {
926 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
927 Expandable::Object(customer) => *customer,
928 };
929
930 let Some(email) = customer.email else {
931 return Ok(None);
932 };
933
934 let Some(user) = app.db.get_user_by_email(&email).await? else {
935 return Ok(None);
936 };
937
938 let billing_customer = app
939 .db
940 .create_billing_customer(&CreateBillingCustomerParams {
941 user_id: user.id,
942 stripe_customer_id: customer.id.to_string(),
943 })
944 .await?;
945
946 Ok(Some(billing_customer))
947}
948
/// The amount of time we wait after each sync of LLM usage to Stripe before
/// starting the next one.
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
950
951pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
952 let Some(stripe_billing) = app.stripe_billing.clone() else {
953 log::warn!("failed to retrieve Stripe billing object");
954 return;
955 };
956 let Some(llm_db) = app.llm_db.clone() else {
957 log::warn!("failed to retrieve LLM database");
958 return;
959 };
960
961 let executor = app.executor.clone();
962 executor.spawn_detached({
963 let executor = executor.clone();
964 async move {
965 loop {
966 sync_with_stripe(&app, &llm_db, &stripe_billing)
967 .await
968 .context("failed to sync LLM usage to Stripe")
969 .trace_err();
970 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
971 }
972 }
973 });
974}
975
976async fn sync_with_stripe(
977 app: &Arc<AppState>,
978 llm_db: &Arc<LlmDatabase>,
979 stripe_billing: &Arc<StripeBilling>,
980) -> anyhow::Result<()> {
981 let events = llm_db.get_billing_events().await?;
982 let user_ids = events
983 .iter()
984 .map(|(event, _)| event.user_id)
985 .collect::<HashSet<UserId>>();
986 let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
987
988 for (event, model) in events {
989 let Some((stripe_db_customer, stripe_db_subscription)) =
990 stripe_subscriptions.get(&event.user_id)
991 else {
992 tracing::warn!(
993 user_id = event.user_id.0,
994 "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
995 );
996 continue;
997 };
998 let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
999 .stripe_subscription_id
1000 .parse()
1001 .context("failed to parse stripe subscription id from db")?;
1002 let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1003 .stripe_customer_id
1004 .parse()
1005 .context("failed to parse stripe customer id from db")?;
1006
1007 let stripe_model = stripe_billing.register_model(&model).await?;
1008 stripe_billing
1009 .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1010 .await?;
1011 stripe_billing
1012 .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
1013 .await?;
1014 llm_db.consume_billing_event(event.id).await?;
1015 }
1016
1017 Ok(())
1018}