1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
30use crate::rpc::{ResultExt as _, Server};
31use crate::{AppState, Cents, Error, Result};
32use crate::{db::UserId, llm::db::LlmDatabase};
33use crate::{
34 db::{
35 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
36 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
37 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
38 },
39 stripe_billing::StripeBilling,
40};
41
42pub fn router() -> Router {
43 Router::new()
44 .route(
45 "/billing/preferences",
46 get(get_billing_preferences).put(update_billing_preferences),
47 )
48 .route(
49 "/billing/subscriptions",
50 get(list_billing_subscriptions).post(create_billing_subscription),
51 )
52 .route(
53 "/billing/subscriptions/manage",
54 post(manage_billing_subscription),
55 )
56 .route("/billing/monthly_spend", get(get_monthly_spend))
57}
58
/// Query-string parameters accepted by `GET /billing/preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID used to look up the user whose preferences are returned.
    github_user_id: i32,
}
63
/// JSON response returned by both the "get" and "update" billing-preferences handlers.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// The user's maximum monthly LLM usage spend, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
}
68
69async fn get_billing_preferences(
70 Extension(app): Extension<Arc<AppState>>,
71 Query(params): Query<GetBillingPreferencesParams>,
72) -> Result<Json<BillingPreferencesResponse>> {
73 let user = app
74 .db
75 .get_user_by_github_user_id(params.github_user_id)
76 .await?
77 .ok_or_else(|| anyhow!("user not found"))?;
78
79 let preferences = app.db.get_billing_preferences(user.id).await?;
80
81 Ok(Json(BillingPreferencesResponse {
82 max_monthly_llm_usage_spending_in_cents: preferences
83 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
84 preferences.max_monthly_llm_usage_spending_in_cents
85 }),
86 }))
87}
88
/// JSON request body accepted by `PUT /billing/preferences`.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID used to look up the user whose preferences are updated.
    github_user_id: i32,
    /// Requested maximum monthly LLM usage spend, in cents.
    ///
    /// The handler clamps negative values to zero before storing them.
    max_monthly_llm_usage_spending_in_cents: i32,
}
94
95async fn update_billing_preferences(
96 Extension(app): Extension<Arc<AppState>>,
97 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
98 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
99) -> Result<Json<BillingPreferencesResponse>> {
100 let user = app
101 .db
102 .get_user_by_github_user_id(body.github_user_id)
103 .await?
104 .ok_or_else(|| anyhow!("user not found"))?;
105
106 let max_monthly_llm_usage_spending_in_cents =
107 body.max_monthly_llm_usage_spending_in_cents.max(0);
108
109 let billing_preferences =
110 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
111 app.db
112 .update_billing_preferences(
113 user.id,
114 &UpdateBillingPreferencesParams {
115 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
116 max_monthly_llm_usage_spending_in_cents,
117 ),
118 },
119 )
120 .await?
121 } else {
122 app.db
123 .create_billing_preferences(
124 user.id,
125 &crate::db::CreateBillingPreferencesParams {
126 max_monthly_llm_usage_spending_in_cents,
127 },
128 )
129 .await?
130 };
131
132 SnowflakeRow::new(
133 "Spend Limit Updated",
134 Some(user.metrics_id),
135 user.admin,
136 None,
137 json!({
138 "user_id": user.id,
139 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
140 }),
141 )
142 .write(&app.kinesis_client, &app.config.kinesis_stream)
143 .await
144 .log_err();
145
146 rpc_server.refresh_llm_tokens_for_user(user.id).await;
147
148 Ok(Json(BillingPreferencesResponse {
149 max_monthly_llm_usage_spending_in_cents: billing_preferences
150 .max_monthly_llm_usage_spending_in_cents,
151 }))
152}
153
/// Query-string parameters accepted by `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}
158
/// JSON representation of a single billing subscription in the list response.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Our database ID for the subscription.
    id: BillingSubscriptionId,
    /// Human-readable plan name derived from the subscription kind.
    name: String,
    /// The subscription status as stored in our database.
    status: StripeSubscriptionStatus,
    /// End of the current period as an RFC 3339 timestamp (millisecond precision).
    ///
    /// Only populated for Zed Pro trial subscriptions; `None` otherwise.
    trial_end_at: Option<String>,
    /// Scheduled cancellation time as an RFC 3339 timestamp (millisecond precision), if any.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    ///
    /// True when the status reports itself as cancelable and no cancellation is
    /// already scheduled.
    is_cancelable: bool,
}
169
/// JSON response returned by `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The subscriptions found for the requested user.
    subscriptions: Vec<BillingSubscriptionJson>,
}
174
175async fn list_billing_subscriptions(
176 Extension(app): Extension<Arc<AppState>>,
177 Query(params): Query<ListBillingSubscriptionsParams>,
178) -> Result<Json<ListBillingSubscriptionsResponse>> {
179 let user = app
180 .db
181 .get_user_by_github_user_id(params.github_user_id)
182 .await?
183 .ok_or_else(|| anyhow!("user not found"))?;
184
185 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
186
187 Ok(Json(ListBillingSubscriptionsResponse {
188 subscriptions: subscriptions
189 .into_iter()
190 .map(|subscription| BillingSubscriptionJson {
191 id: subscription.id,
192 name: match subscription.kind {
193 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
194 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
195 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
196 None => "Zed LLM Usage".to_string(),
197 },
198 status: subscription.stripe_subscription_status,
199 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
200 maybe!({
201 let end_at = subscription.stripe_current_period_end?;
202 let end_at = DateTime::from_timestamp(end_at, 0)?;
203
204 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
205 })
206 } else {
207 None
208 },
209 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
210 cancel_at
211 .and_utc()
212 .to_rfc3339_opts(SecondsFormat::Millis, true)
213 }),
214 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
215 && subscription.stripe_cancel_at.is_none(),
216 })
217 .collect(),
218 }))
219}
220
/// The product a caller wants to check out.
///
/// Deserialized from snake_case names (per the `rename_all` attribute).
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// Checkout uses the configured Zed Pro price.
    ZedPro,
    /// Checkout uses the configured Zed Pro trial price.
    ZedProTrial,
}
227
/// JSON request body accepted by `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the user who is subscribing.
    github_user_id: i32,
    /// The product to check out.
    ///
    /// When absent, the handler starts a checkout for the default LLM model instead.
    product: Option<ProductCode>,
}
233
/// JSON response returned by `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the checkout session the user should be sent to.
    checkout_session_url: String,
}
238
239/// Initiates a Stripe Checkout session for creating a billing subscription.
240async fn create_billing_subscription(
241 Extension(app): Extension<Arc<AppState>>,
242 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
243) -> Result<Json<CreateBillingSubscriptionResponse>> {
244 let user = app
245 .db
246 .get_user_by_github_user_id(body.github_user_id)
247 .await?
248 .ok_or_else(|| anyhow!("user not found"))?;
249
250 let Some(stripe_client) = app.stripe_client.clone() else {
251 log::error!("failed to retrieve Stripe client");
252 Err(Error::http(
253 StatusCode::NOT_IMPLEMENTED,
254 "not supported".into(),
255 ))?
256 };
257 let Some(stripe_billing) = app.stripe_billing.clone() else {
258 log::error!("failed to retrieve Stripe billing object");
259 Err(Error::http(
260 StatusCode::NOT_IMPLEMENTED,
261 "not supported".into(),
262 ))?
263 };
264 let Some(llm_db) = app.llm_db.clone() else {
265 log::error!("failed to retrieve LLM database");
266 Err(Error::http(
267 StatusCode::NOT_IMPLEMENTED,
268 "not supported".into(),
269 ))?
270 };
271
272 if app.db.has_active_billing_subscription(user.id).await? {
273 return Err(Error::http(
274 StatusCode::CONFLICT,
275 "user already has an active subscription".into(),
276 ));
277 }
278
279 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
280 if let Some(existing_billing_customer) = &existing_billing_customer {
281 if existing_billing_customer.has_overdue_invoices {
282 return Err(Error::http(
283 StatusCode::PAYMENT_REQUIRED,
284 "user has overdue invoices".into(),
285 ));
286 }
287 }
288
289 let customer_id = if let Some(existing_customer) = existing_billing_customer {
290 CustomerId::from_str(&existing_customer.stripe_customer_id)
291 .context("failed to parse customer ID")?
292 } else {
293 let customer = Customer::create(
294 &stripe_client,
295 CreateCustomer {
296 email: user.email_address.as_deref(),
297 ..Default::default()
298 },
299 )
300 .await?;
301
302 customer.id
303 };
304
305 let success_url = format!(
306 "{}/account?checkout_complete=1",
307 app.config.zed_dot_dev_url()
308 );
309
310 let checkout_session_url = match body.product {
311 Some(ProductCode::ZedPro) => {
312 stripe_billing
313 .checkout_with_price(
314 app.config.zed_pro_price_id()?,
315 customer_id,
316 &user.github_login,
317 &success_url,
318 )
319 .await?
320 }
321 Some(ProductCode::ZedProTrial) => {
322 stripe_billing
323 .checkout_with_price(
324 app.config.zed_pro_trial_price_id()?,
325 customer_id,
326 &user.github_login,
327 &success_url,
328 )
329 .await?
330 }
331 None => {
332 let default_model =
333 llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
334 let stripe_model = stripe_billing.register_model(default_model).await?;
335 stripe_billing
336 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
337 .await?
338 }
339 };
340
341 Ok(Json(CreateBillingSubscriptionResponse {
342 checkout_session_url,
343 }))
344}
345
/// What the caller wants to do with an existing subscription.
///
/// Deserialized from snake_case names (per the `rename_all` attribute).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to upgrade to Zed Pro.
    ///
    /// Opens the billing portal in a subscription-update confirmation flow.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    ///
    /// Opens the billing portal in a subscription-cancel flow.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    ///
    /// Handled directly through the Stripe API; no billing portal session is created.
    StopCancellation,
}
360
/// JSON request body accepted by `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user (and through it, the billing customer).
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
}
368
/// JSON response returned by `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the created billing portal session.
    ///
    /// `None` when the intent was to stop a cancellation, since that is handled
    /// without a portal session.
    billing_portal_session_url: Option<String>,
}
373
374/// Initiates a Stripe customer portal session for managing a billing subscription.
375async fn manage_billing_subscription(
376 Extension(app): Extension<Arc<AppState>>,
377 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
378) -> Result<Json<ManageBillingSubscriptionResponse>> {
379 let user = app
380 .db
381 .get_user_by_github_user_id(body.github_user_id)
382 .await?
383 .ok_or_else(|| anyhow!("user not found"))?;
384
385 let Some(stripe_client) = app.stripe_client.clone() else {
386 log::error!("failed to retrieve Stripe client");
387 Err(Error::http(
388 StatusCode::NOT_IMPLEMENTED,
389 "not supported".into(),
390 ))?
391 };
392
393 let customer = app
394 .db
395 .get_billing_customer_by_user_id(user.id)
396 .await?
397 .ok_or_else(|| anyhow!("billing customer not found"))?;
398 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
399 .context("failed to parse customer ID")?;
400
401 let subscription = app
402 .db
403 .get_billing_subscription_by_id(body.subscription_id)
404 .await?
405 .ok_or_else(|| anyhow!("subscription not found"))?;
406 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
407 .context("failed to parse subscription ID")?;
408
409 if body.intent == ManageSubscriptionIntent::StopCancellation {
410 let updated_stripe_subscription = Subscription::update(
411 &stripe_client,
412 &subscription_id,
413 stripe::UpdateSubscription {
414 cancel_at_period_end: Some(false),
415 ..Default::default()
416 },
417 )
418 .await?;
419
420 app.db
421 .update_billing_subscription(
422 subscription.id,
423 &UpdateBillingSubscriptionParams {
424 stripe_cancel_at: ActiveValue::set(
425 updated_stripe_subscription
426 .cancel_at
427 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
428 .map(|time| time.naive_utc()),
429 ),
430 ..Default::default()
431 },
432 )
433 .await?;
434
435 return Ok(Json(ManageBillingSubscriptionResponse {
436 billing_portal_session_url: None,
437 }));
438 }
439
440 let flow = match body.intent {
441 ManageSubscriptionIntent::ManageSubscription => None,
442 ManageSubscriptionIntent::UpgradeToPro => {
443 let zed_pro_price_id = app.config.zed_pro_price_id()?;
444 let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id()?;
445 let zed_free_price_id = app.config.zed_free_price_id()?;
446
447 let stripe_subscription =
448 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
449
450 let subscription_item_to_update = stripe_subscription
451 .items
452 .data
453 .iter()
454 .find_map(|item| {
455 let price = item.price.as_ref()?;
456
457 if price.id == zed_free_price_id || price.id == zed_pro_trial_price_id {
458 Some(item.id.clone())
459 } else {
460 None
461 }
462 })
463 .ok_or_else(|| anyhow!("No subscription item to update"))?;
464
465 Some(CreateBillingPortalSessionFlowData {
466 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
467 subscription_update_confirm: Some(
468 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
469 subscription: subscription.stripe_subscription_id,
470 items: vec![
471 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
472 id: subscription_item_to_update.to_string(),
473 price: Some(zed_pro_price_id.to_string()),
474 quantity: Some(1),
475 },
476 ],
477 discounts: None,
478 },
479 ),
480 ..Default::default()
481 })
482 }
483 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
484 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
485 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
486 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
487 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
488 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
489 }),
490 ..Default::default()
491 }),
492 subscription_cancel: Some(
493 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
494 subscription: subscription.stripe_subscription_id,
495 retention: None,
496 },
497 ),
498 ..Default::default()
499 }),
500 ManageSubscriptionIntent::StopCancellation => unreachable!(),
501 };
502
503 let mut params = CreateBillingPortalSession::new(customer_id);
504 params.flow_data = flow;
505 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
506 params.return_url = Some(&return_url);
507
508 let session = BillingPortalSession::create(&stripe_client, params).await?;
509
510 Ok(Json(ManageBillingSubscriptionResponse {
511 billing_portal_session_url: Some(session.url),
512 }))
513}
514
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
///
/// We currently poll every 5 seconds, i.e. ten times less often than that reference.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);

/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// The quote below is presumably taken from the Stripe API documentation for
/// listing events — confirm the source before relying on it:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
///
/// The count is cumulative over one poll (the pages need not be consecutive) and
/// is only consulted while Stripe reports that more pages are available.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
541
542/// Polls the Stripe events API periodically to reconcile the records in our
543/// database with the data in Stripe.
544pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
545 let Some(stripe_client) = app.stripe_client.clone() else {
546 log::warn!("failed to retrieve Stripe client");
547 return;
548 };
549
550 let executor = app.executor.clone();
551 executor.spawn_detached({
552 let executor = executor.clone();
553 async move {
554 loop {
555 poll_stripe_events(&app, &rpc_server, &stripe_client)
556 .await
557 .log_err();
558
559 executor.sleep(POLL_EVENTS_INTERVAL).await;
560 }
561 }
562 });
563}
564
565async fn poll_stripe_events(
566 app: &Arc<AppState>,
567 rpc_server: &Arc<Server>,
568 stripe_client: &stripe::Client,
569) -> anyhow::Result<()> {
570 fn event_type_to_string(event_type: EventType) -> String {
571 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
572 // so we need to unquote it.
573 event_type.to_string().trim_matches('"').to_string()
574 }
575
576 let event_types = [
577 EventType::CustomerCreated,
578 EventType::CustomerUpdated,
579 EventType::CustomerSubscriptionCreated,
580 EventType::CustomerSubscriptionUpdated,
581 EventType::CustomerSubscriptionPaused,
582 EventType::CustomerSubscriptionResumed,
583 EventType::CustomerSubscriptionDeleted,
584 ]
585 .into_iter()
586 .map(event_type_to_string)
587 .collect::<Vec<_>>();
588
589 let mut pages_of_already_processed_events = 0;
590 let mut unprocessed_events = Vec::new();
591
592 log::info!(
593 "Stripe events: starting retrieval for {}",
594 event_types.join(", ")
595 );
596 let mut params = ListEvents::new();
597 params.types = Some(event_types.clone());
598 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
599
600 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
601 .await?
602 .paginate(params);
603
604 loop {
605 let processed_event_ids = {
606 let event_ids = event_pages
607 .page
608 .data
609 .iter()
610 .map(|event| event.id.as_str())
611 .collect::<Vec<_>>();
612 app.db
613 .get_processed_stripe_events_by_event_ids(&event_ids)
614 .await?
615 .into_iter()
616 .map(|event| event.stripe_event_id)
617 .collect::<Vec<_>>()
618 };
619
620 let mut processed_events_in_page = 0;
621 let events_in_page = event_pages.page.data.len();
622 for event in &event_pages.page.data {
623 if processed_event_ids.contains(&event.id.to_string()) {
624 processed_events_in_page += 1;
625 log::debug!("Stripe events: already processed '{}', skipping", event.id);
626 } else {
627 unprocessed_events.push(event.clone());
628 }
629 }
630
631 if processed_events_in_page == events_in_page {
632 pages_of_already_processed_events += 1;
633 }
634
635 if event_pages.page.has_more {
636 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
637 {
638 log::info!(
639 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
640 );
641 break;
642 } else {
643 log::info!("Stripe events: retrieving next page");
644 event_pages = event_pages.next(&stripe_client).await?;
645 }
646 } else {
647 break;
648 }
649 }
650
651 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
652
653 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
654 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
655
656 for event in unprocessed_events {
657 let event_id = event.id.clone();
658 let processed_event_params = CreateProcessedStripeEventParams {
659 stripe_event_id: event.id.to_string(),
660 stripe_event_type: event_type_to_string(event.type_),
661 stripe_event_created_timestamp: event.created,
662 };
663
664 // If the event has happened too far in the past, we don't want to
665 // process it and risk overwriting other more-recent updates.
666 //
667 // 1 day was chosen arbitrarily. This could be made longer or shorter.
668 let one_day = Duration::from_secs(24 * 60 * 60);
669 let a_day_ago = Utc::now() - one_day;
670 if a_day_ago.timestamp() > event.created {
671 log::info!(
672 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
673 event_id
674 );
675 app.db
676 .create_processed_stripe_event(&processed_event_params)
677 .await?;
678
679 return Ok(());
680 }
681
682 let process_result = match event.type_ {
683 EventType::CustomerCreated | EventType::CustomerUpdated => {
684 handle_customer_event(app, stripe_client, event).await
685 }
686 EventType::CustomerSubscriptionCreated
687 | EventType::CustomerSubscriptionUpdated
688 | EventType::CustomerSubscriptionPaused
689 | EventType::CustomerSubscriptionResumed
690 | EventType::CustomerSubscriptionDeleted => {
691 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
692 }
693 _ => Ok(()),
694 };
695
696 if let Some(()) = process_result
697 .with_context(|| format!("failed to process event {event_id} successfully"))
698 .log_err()
699 {
700 app.db
701 .create_processed_stripe_event(&processed_event_params)
702 .await?;
703 }
704 }
705
706 Ok(())
707}
708
709async fn handle_customer_event(
710 app: &Arc<AppState>,
711 _stripe_client: &stripe::Client,
712 event: stripe::Event,
713) -> anyhow::Result<()> {
714 let EventObject::Customer(customer) = event.data.object else {
715 bail!("unexpected event payload for {}", event.id);
716 };
717
718 log::info!("handling Stripe {} event: {}", event.type_, event.id);
719
720 let Some(email) = customer.email else {
721 log::info!("Stripe customer has no email: skipping");
722 return Ok(());
723 };
724
725 let Some(user) = app.db.get_user_by_email(&email).await? else {
726 log::info!("no user found for email: skipping");
727 return Ok(());
728 };
729
730 if let Some(existing_customer) = app
731 .db
732 .get_billing_customer_by_stripe_customer_id(&customer.id)
733 .await?
734 {
735 app.db
736 .update_billing_customer(
737 existing_customer.id,
738 &UpdateBillingCustomerParams {
739 // For now we just leave the information as-is, as it is not
740 // likely to change.
741 ..Default::default()
742 },
743 )
744 .await?;
745 } else {
746 app.db
747 .create_billing_customer(&CreateBillingCustomerParams {
748 user_id: user.id,
749 stripe_customer_id: customer.id.to_string(),
750 })
751 .await?;
752 }
753
754 Ok(())
755}
756
757async fn handle_customer_subscription_event(
758 app: &Arc<AppState>,
759 rpc_server: &Arc<Server>,
760 stripe_client: &stripe::Client,
761 event: stripe::Event,
762) -> anyhow::Result<()> {
763 let EventObject::Subscription(subscription) = event.data.object else {
764 bail!("unexpected event payload for {}", event.id);
765 };
766
767 log::info!("handling Stripe {} event: {}", event.type_, event.id);
768
769 let subscription_kind = maybe!({
770 let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
771 let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id().ok()?;
772 let zed_free_price_id = app.config.zed_free_price_id().ok()?;
773
774 subscription.items.data.iter().find_map(|item| {
775 let price = item.price.as_ref()?;
776
777 if price.id == zed_pro_price_id {
778 Some(SubscriptionKind::ZedPro)
779 } else if price.id == zed_pro_trial_price_id {
780 Some(SubscriptionKind::ZedProTrial)
781 } else if price.id == zed_free_price_id {
782 Some(SubscriptionKind::ZedFree)
783 } else {
784 None
785 }
786 })
787 });
788
789 let billing_customer =
790 find_or_create_billing_customer(app, stripe_client, subscription.customer)
791 .await?
792 .ok_or_else(|| anyhow!("billing customer not found"))?;
793
794 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
795 && subscription
796 .cancellation_details
797 .as_ref()
798 .and_then(|details| details.reason)
799 .map_or(false, |reason| {
800 reason == CancellationDetailsReason::PaymentFailed
801 });
802
803 if was_canceled_due_to_payment_failure {
804 app.db
805 .update_billing_customer(
806 billing_customer.id,
807 &UpdateBillingCustomerParams {
808 has_overdue_invoices: ActiveValue::set(true),
809 ..Default::default()
810 },
811 )
812 .await?;
813 }
814
815 if let Some(existing_subscription) = app
816 .db
817 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
818 .await?
819 {
820 app.db
821 .update_billing_subscription(
822 existing_subscription.id,
823 &UpdateBillingSubscriptionParams {
824 billing_customer_id: ActiveValue::set(billing_customer.id),
825 kind: ActiveValue::set(subscription_kind),
826 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
827 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
828 stripe_cancel_at: ActiveValue::set(
829 subscription
830 .cancel_at
831 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
832 .map(|time| time.naive_utc()),
833 ),
834 stripe_cancellation_reason: ActiveValue::set(
835 subscription
836 .cancellation_details
837 .and_then(|details| details.reason)
838 .map(|reason| reason.into()),
839 ),
840 stripe_current_period_start: ActiveValue::set(Some(
841 subscription.current_period_start,
842 )),
843 stripe_current_period_end: ActiveValue::set(Some(
844 subscription.current_period_end,
845 )),
846 },
847 )
848 .await?;
849 } else {
850 // If the user already has an active billing subscription, ignore the
851 // event and return an `Ok` to signal that it was processed
852 // successfully.
853 //
854 // There is the possibility that this could cause us to not create a
855 // subscription in the following scenario:
856 //
857 // 1. User has an active subscription A
858 // 2. User cancels subscription A
859 // 3. User creates a new subscription B
860 // 4. We process the new subscription B before the cancellation of subscription A
861 // 5. User ends up with no subscriptions
862 //
863 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
864 if app
865 .db
866 .has_active_billing_subscription(billing_customer.user_id)
867 .await?
868 {
869 log::info!(
870 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
871 user_id = billing_customer.user_id,
872 subscription_id = subscription.id
873 );
874 return Ok(());
875 }
876
877 app.db
878 .create_billing_subscription(&CreateBillingSubscriptionParams {
879 billing_customer_id: billing_customer.id,
880 kind: subscription_kind,
881 stripe_subscription_id: subscription.id.to_string(),
882 stripe_subscription_status: subscription.status.into(),
883 stripe_cancellation_reason: subscription
884 .cancellation_details
885 .and_then(|details| details.reason)
886 .map(|reason| reason.into()),
887 stripe_current_period_start: Some(subscription.current_period_start),
888 stripe_current_period_end: Some(subscription.current_period_end),
889 })
890 .await?;
891 }
892
893 // When the user's subscription changes, we want to refresh their LLM tokens
894 // to either grant/revoke access.
895 rpc_server
896 .refresh_llm_tokens_for_user(billing_customer.user_id)
897 .await;
898
899 Ok(())
900}
901
/// Query-string parameters accepted by `GET /billing/monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID used to look up the user whose spending is reported.
    github_user_id: i32,
}
906
/// JSON response returned by `GET /billing/monthly_spend`. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// The part of this month's spending covered by the free tier: the smaller of
    /// the month's total spending and the free-tier allowance.
    monthly_free_tier_spend_in_cents: u32,
    /// The user's free-tier allowance: their custom allowance when one is set,
    /// otherwise the default free-tier limit.
    monthly_free_tier_allowance_in_cents: u32,
    /// The part of this month's spending that exceeds the free-tier allowance
    /// (zero when the allowance has not been used up).
    monthly_spend_in_cents: u32,
}
913
914async fn get_monthly_spend(
915 Extension(app): Extension<Arc<AppState>>,
916 Query(params): Query<GetMonthlySpendParams>,
917) -> Result<Json<GetMonthlySpendResponse>> {
918 let user = app
919 .db
920 .get_user_by_github_user_id(params.github_user_id)
921 .await?
922 .ok_or_else(|| anyhow!("user not found"))?;
923
924 let Some(llm_db) = app.llm_db.clone() else {
925 return Err(Error::http(
926 StatusCode::NOT_IMPLEMENTED,
927 "LLM database not available".into(),
928 ));
929 };
930
931 let free_tier = user
932 .custom_llm_monthly_allowance_in_cents
933 .map(|allowance| Cents(allowance as u32))
934 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
935
936 let spending_for_month = llm_db
937 .get_user_spending_for_month(user.id, Utc::now())
938 .await?;
939
940 let free_tier_spend = Cents::min(spending_for_month, free_tier);
941 let monthly_spend = spending_for_month.saturating_sub(free_tier);
942
943 Ok(Json(GetMonthlySpendResponse {
944 monthly_free_tier_spend_in_cents: free_tier_spend.0,
945 monthly_free_tier_allowance_in_cents: free_tier.0,
946 monthly_spend_in_cents: monthly_spend.0,
947 }))
948}
949
950impl From<SubscriptionStatus> for StripeSubscriptionStatus {
951 fn from(value: SubscriptionStatus) -> Self {
952 match value {
953 SubscriptionStatus::Incomplete => Self::Incomplete,
954 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
955 SubscriptionStatus::Trialing => Self::Trialing,
956 SubscriptionStatus::Active => Self::Active,
957 SubscriptionStatus::PastDue => Self::PastDue,
958 SubscriptionStatus::Canceled => Self::Canceled,
959 SubscriptionStatus::Unpaid => Self::Unpaid,
960 SubscriptionStatus::Paused => Self::Paused,
961 }
962 }
963}
964
965impl From<CancellationDetailsReason> for StripeCancellationReason {
966 fn from(value: CancellationDetailsReason) -> Self {
967 match value {
968 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
969 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
970 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
971 }
972 }
973}
974
975/// Finds or creates a billing customer using the provided customer.
976async fn find_or_create_billing_customer(
977 app: &Arc<AppState>,
978 stripe_client: &stripe::Client,
979 customer_or_id: Expandable<Customer>,
980) -> anyhow::Result<Option<billing_customer::Model>> {
981 let customer_id = match &customer_or_id {
982 Expandable::Id(id) => id,
983 Expandable::Object(customer) => customer.id.as_ref(),
984 };
985
986 // If we already have a billing customer record associated with the Stripe customer,
987 // there's nothing more we need to do.
988 if let Some(billing_customer) = app
989 .db
990 .get_billing_customer_by_stripe_customer_id(customer_id)
991 .await?
992 {
993 return Ok(Some(billing_customer));
994 }
995
996 // If all we have is a customer ID, resolve it to a full customer record by
997 // hitting the Stripe API.
998 let customer = match customer_or_id {
999 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1000 Expandable::Object(customer) => *customer,
1001 };
1002
1003 let Some(email) = customer.email else {
1004 return Ok(None);
1005 };
1006
1007 let Some(user) = app.db.get_user_by_email(&email).await? else {
1008 return Ok(None);
1009 };
1010
1011 let billing_customer = app
1012 .db
1013 .create_billing_customer(&CreateBillingCustomerParams {
1014 user_id: user.id,
1015 stripe_customer_id: customer.id.to_string(),
1016 })
1017 .await?;
1018
1019 Ok(Some(billing_customer))
1020}
1021
/// The amount of time we wait in between each sync of LLM usage with Stripe.
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1023
1024pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
1025 let Some(stripe_billing) = app.stripe_billing.clone() else {
1026 log::warn!("failed to retrieve Stripe billing object");
1027 return;
1028 };
1029 let Some(llm_db) = app.llm_db.clone() else {
1030 log::warn!("failed to retrieve LLM database");
1031 return;
1032 };
1033
1034 let executor = app.executor.clone();
1035 executor.spawn_detached({
1036 let executor = executor.clone();
1037 async move {
1038 loop {
1039 sync_with_stripe(&app, &llm_db, &stripe_billing)
1040 .await
1041 .context("failed to sync LLM usage to Stripe")
1042 .trace_err();
1043 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
1044 }
1045 }
1046 });
1047}
1048
1049async fn sync_with_stripe(
1050 app: &Arc<AppState>,
1051 llm_db: &Arc<LlmDatabase>,
1052 stripe_billing: &Arc<StripeBilling>,
1053) -> anyhow::Result<()> {
1054 let events = llm_db.get_billing_events().await?;
1055 let user_ids = events
1056 .iter()
1057 .map(|(event, _)| event.user_id)
1058 .collect::<HashSet<UserId>>();
1059 let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1060
1061 for (event, model) in events {
1062 let Some((stripe_db_customer, stripe_db_subscription)) =
1063 stripe_subscriptions.get(&event.user_id)
1064 else {
1065 tracing::warn!(
1066 user_id = event.user_id.0,
1067 "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1068 );
1069 continue;
1070 };
1071 let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1072 .stripe_subscription_id
1073 .parse()
1074 .context("failed to parse stripe subscription id from db")?;
1075 let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1076 .stripe_customer_id
1077 .parse()
1078 .context("failed to parse stripe customer id from db")?;
1079
1080 let stripe_model = stripe_billing.register_model(&model).await?;
1081 stripe_billing
1082 .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1083 .await?;
1084 stripe_billing
1085 .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
1086 .await?;
1087 llm_db.consume_billing_event(event.id).await?;
1088 }
1089
1090 Ok(())
1091}