1use anyhow::{Context as _, bail};
2use axum::{Extension, Json, Router, extract, routing::post};
3use chrono::{DateTime, Utc};
4use collections::{HashMap, HashSet};
5use reqwest::StatusCode;
6use sea_orm::ActiveValue;
7use serde::{Deserialize, Serialize};
8use std::{str::FromStr, sync::Arc, time::Duration};
9use stripe::{
10 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
11 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
12 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
13 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
14 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
15 CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
16 PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
17};
18use util::{ResultExt, maybe};
19use zed_llm_client::LanguageModelProvider;
20
21use crate::db::billing_subscription::{
22 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
23};
24use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
25use crate::rpc::{ResultExt as _, Server};
26use crate::stripe_client::{
27 StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
28 StripeSubscriptionId,
29};
30use crate::{AppState, Error, Result};
31use crate::{db::UserId, llm::db::LlmDatabase};
32use crate::{
33 db::{
34 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
35 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
36 UpdateBillingSubscriptionParams, billing_customer,
37 },
38 stripe_billing::StripeBilling,
39};
40
41pub fn router() -> Router {
42 Router::new()
43 .route(
44 "/billing/subscriptions/manage",
45 post(manage_billing_subscription),
46 )
47 .route(
48 "/billing/subscriptions/sync",
49 post(sync_billing_subscription),
50 )
51}
52
53#[derive(Debug, PartialEq, Deserialize)]
54#[serde(rename_all = "snake_case")]
55enum ManageSubscriptionIntent {
56 /// The user intends to manage their subscription.
57 ///
58 /// This will open the Stripe billing portal without putting the user in a specific flow.
59 ManageSubscription,
60 /// The user intends to update their payment method.
61 UpdatePaymentMethod,
62 /// The user intends to upgrade to Zed Pro.
63 UpgradeToPro,
64 /// The user intends to cancel their subscription.
65 Cancel,
66 /// The user intends to stop the cancellation of their subscription.
67 StopCancellation,
68}
69
70#[derive(Debug, Deserialize)]
71struct ManageBillingSubscriptionBody {
72 github_user_id: i32,
73 intent: ManageSubscriptionIntent,
74 /// The ID of the subscription to manage.
75 subscription_id: BillingSubscriptionId,
76 redirect_to: Option<String>,
77}
78
79#[derive(Debug, Serialize)]
80struct ManageBillingSubscriptionResponse {
81 billing_portal_session_url: Option<String>,
82}
83
84/// Initiates a Stripe customer portal session for managing a billing subscription.
85async fn manage_billing_subscription(
86 Extension(app): Extension<Arc<AppState>>,
87 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
88) -> Result<Json<ManageBillingSubscriptionResponse>> {
89 let user = app
90 .db
91 .get_user_by_github_user_id(body.github_user_id)
92 .await?
93 .context("user not found")?;
94
95 let Some(stripe_client) = app.real_stripe_client.clone() else {
96 log::error!("failed to retrieve Stripe client");
97 Err(Error::http(
98 StatusCode::NOT_IMPLEMENTED,
99 "not supported".into(),
100 ))?
101 };
102
103 let Some(stripe_billing) = app.stripe_billing.clone() else {
104 log::error!("failed to retrieve Stripe billing object");
105 Err(Error::http(
106 StatusCode::NOT_IMPLEMENTED,
107 "not supported".into(),
108 ))?
109 };
110
111 let customer = app
112 .db
113 .get_billing_customer_by_user_id(user.id)
114 .await?
115 .context("billing customer not found")?;
116 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
117 .context("failed to parse customer ID")?;
118
119 let subscription = app
120 .db
121 .get_billing_subscription_by_id(body.subscription_id)
122 .await?
123 .context("subscription not found")?;
124 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
125 .context("failed to parse subscription ID")?;
126
127 if body.intent == ManageSubscriptionIntent::StopCancellation {
128 let updated_stripe_subscription = Subscription::update(
129 &stripe_client,
130 &subscription_id,
131 stripe::UpdateSubscription {
132 cancel_at_period_end: Some(false),
133 ..Default::default()
134 },
135 )
136 .await?;
137
138 app.db
139 .update_billing_subscription(
140 subscription.id,
141 &UpdateBillingSubscriptionParams {
142 stripe_cancel_at: ActiveValue::set(
143 updated_stripe_subscription
144 .cancel_at
145 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
146 .map(|time| time.naive_utc()),
147 ),
148 ..Default::default()
149 },
150 )
151 .await?;
152
153 return Ok(Json(ManageBillingSubscriptionResponse {
154 billing_portal_session_url: None,
155 }));
156 }
157
158 let flow = match body.intent {
159 ManageSubscriptionIntent::ManageSubscription => None,
160 ManageSubscriptionIntent::UpgradeToPro => {
161 let zed_pro_price_id: stripe::PriceId =
162 stripe_billing.zed_pro_price_id().await?.try_into()?;
163 let zed_free_price_id: stripe::PriceId =
164 stripe_billing.zed_free_price_id().await?.try_into()?;
165
166 let stripe_subscription =
167 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
168
169 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
170 && stripe_subscription.items.data.iter().any(|item| {
171 item.price
172 .as_ref()
173 .map_or(false, |price| price.id == zed_pro_price_id)
174 });
175 if is_on_zed_pro_trial {
176 let payment_methods = PaymentMethod::list(
177 &stripe_client,
178 &stripe::ListPaymentMethods {
179 customer: Some(stripe_subscription.customer.id()),
180 ..Default::default()
181 },
182 )
183 .await?;
184
185 let has_payment_method = !payment_methods.data.is_empty();
186 if !has_payment_method {
187 return Err(Error::http(
188 StatusCode::BAD_REQUEST,
189 "missing payment method".into(),
190 ));
191 }
192
193 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
194 Subscription::update(
195 &stripe_client,
196 &stripe_subscription.id,
197 stripe::UpdateSubscription {
198 trial_end: Some(stripe::Scheduled::now()),
199 ..Default::default()
200 },
201 )
202 .await?;
203
204 return Ok(Json(ManageBillingSubscriptionResponse {
205 billing_portal_session_url: None,
206 }));
207 }
208
209 let subscription_item_to_update = stripe_subscription
210 .items
211 .data
212 .iter()
213 .find_map(|item| {
214 let price = item.price.as_ref()?;
215
216 if price.id == zed_free_price_id {
217 Some(item.id.clone())
218 } else {
219 None
220 }
221 })
222 .context("No subscription item to update")?;
223
224 Some(CreateBillingPortalSessionFlowData {
225 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
226 subscription_update_confirm: Some(
227 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
228 subscription: subscription.stripe_subscription_id,
229 items: vec![
230 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
231 id: subscription_item_to_update.to_string(),
232 price: Some(zed_pro_price_id.to_string()),
233 quantity: Some(1),
234 },
235 ],
236 discounts: None,
237 },
238 ),
239 ..Default::default()
240 })
241 }
242 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
243 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
244 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
245 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
246 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
247 return_url: format!(
248 "{}{path}",
249 app.config.zed_dot_dev_url(),
250 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
251 ),
252 }),
253 ..Default::default()
254 }),
255 ..Default::default()
256 }),
257 ManageSubscriptionIntent::Cancel => {
258 if subscription.kind == Some(SubscriptionKind::ZedFree) {
259 return Err(Error::http(
260 StatusCode::BAD_REQUEST,
261 "free subscription cannot be canceled".into(),
262 ));
263 }
264
265 Some(CreateBillingPortalSessionFlowData {
266 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
267 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
268 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
269 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
270 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
271 }),
272 ..Default::default()
273 }),
274 subscription_cancel: Some(
275 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
276 subscription: subscription.stripe_subscription_id,
277 retention: None,
278 },
279 ),
280 ..Default::default()
281 })
282 }
283 ManageSubscriptionIntent::StopCancellation => unreachable!(),
284 };
285
286 let mut params = CreateBillingPortalSession::new(customer_id);
287 params.flow_data = flow;
288 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
289 params.return_url = Some(&return_url);
290
291 let session = BillingPortalSession::create(&stripe_client, params).await?;
292
293 Ok(Json(ManageBillingSubscriptionResponse {
294 billing_portal_session_url: Some(session.url),
295 }))
296}
297
298#[derive(Debug, Deserialize)]
299struct SyncBillingSubscriptionBody {
300 github_user_id: i32,
301}
302
303#[derive(Debug, Serialize)]
304struct SyncBillingSubscriptionResponse {
305 stripe_customer_id: String,
306}
307
308async fn sync_billing_subscription(
309 Extension(app): Extension<Arc<AppState>>,
310 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
311) -> Result<Json<SyncBillingSubscriptionResponse>> {
312 let Some(stripe_client) = app.stripe_client.clone() else {
313 log::error!("failed to retrieve Stripe client");
314 Err(Error::http(
315 StatusCode::NOT_IMPLEMENTED,
316 "not supported".into(),
317 ))?
318 };
319
320 let user = app
321 .db
322 .get_user_by_github_user_id(body.github_user_id)
323 .await?
324 .context("user not found")?;
325
326 let billing_customer = app
327 .db
328 .get_billing_customer_by_user_id(user.id)
329 .await?
330 .context("billing customer not found")?;
331 let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
332
333 let subscriptions = stripe_client
334 .list_subscriptions_for_customer(&stripe_customer_id)
335 .await?;
336
337 for subscription in subscriptions {
338 let subscription_id = subscription.id.clone();
339
340 sync_subscription(&app, &stripe_client, subscription)
341 .await
342 .with_context(|| {
343 format!(
344 "failed to sync subscription {subscription_id} for user {}",
345 user.id,
346 )
347 })?;
348 }
349
350 Ok(Json(SyncBillingSubscriptionResponse {
351 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
352 }))
353}
354
/// How long we wait between consecutive polls of the Stripe events API.
///
/// The interval trades off two concerns:
/// 1. It should be short enough that changes made in Stripe show up here quickly.
/// 2. It should be long enough that polling doesn't eat into our rate limits.
///
/// For comparison, the Sequin folks report polling every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// Page size used when listing Stripe events.
///
/// 100 is the largest value Stripe accepts, which keeps the number of
/// requests we make to Stripe as low as possible.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we will see
/// before we stop paginating through Stripe events.
///
/// This keeps us from over-fetching events from the Stripe API that we have
/// already seen and handled.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
381
382/// Polls the Stripe events API periodically to reconcile the records in our
383/// database with the data in Stripe.
384pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
385 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
386 log::warn!("failed to retrieve Stripe client");
387 return;
388 };
389 let Some(stripe_client) = app.stripe_client.clone() else {
390 log::warn!("failed to retrieve Stripe client");
391 return;
392 };
393
394 let executor = app.executor.clone();
395 executor.spawn_detached({
396 let executor = executor.clone();
397 async move {
398 loop {
399 poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
400 .await
401 .log_err();
402
403 executor.sleep(POLL_EVENTS_INTERVAL).await;
404 }
405 }
406 });
407}
408
409async fn poll_stripe_events(
410 app: &Arc<AppState>,
411 rpc_server: &Arc<Server>,
412 stripe_client: &Arc<dyn StripeClient>,
413 real_stripe_client: &stripe::Client,
414) -> anyhow::Result<()> {
415 fn event_type_to_string(event_type: EventType) -> String {
416 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
417 // so we need to unquote it.
418 event_type.to_string().trim_matches('"').to_string()
419 }
420
421 let event_types = [
422 EventType::CustomerCreated,
423 EventType::CustomerUpdated,
424 EventType::CustomerSubscriptionCreated,
425 EventType::CustomerSubscriptionUpdated,
426 EventType::CustomerSubscriptionPaused,
427 EventType::CustomerSubscriptionResumed,
428 EventType::CustomerSubscriptionDeleted,
429 ]
430 .into_iter()
431 .map(event_type_to_string)
432 .collect::<Vec<_>>();
433
434 let mut pages_of_already_processed_events = 0;
435 let mut unprocessed_events = Vec::new();
436
437 log::info!(
438 "Stripe events: starting retrieval for {}",
439 event_types.join(", ")
440 );
441 let mut params = ListEvents::new();
442 params.types = Some(event_types.clone());
443 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
444
445 let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
446 .await?
447 .paginate(params);
448
449 loop {
450 let processed_event_ids = {
451 let event_ids = event_pages
452 .page
453 .data
454 .iter()
455 .map(|event| event.id.as_str())
456 .collect::<Vec<_>>();
457 app.db
458 .get_processed_stripe_events_by_event_ids(&event_ids)
459 .await?
460 .into_iter()
461 .map(|event| event.stripe_event_id)
462 .collect::<Vec<_>>()
463 };
464
465 let mut processed_events_in_page = 0;
466 let events_in_page = event_pages.page.data.len();
467 for event in &event_pages.page.data {
468 if processed_event_ids.contains(&event.id.to_string()) {
469 processed_events_in_page += 1;
470 log::debug!("Stripe events: already processed '{}', skipping", event.id);
471 } else {
472 unprocessed_events.push(event.clone());
473 }
474 }
475
476 if processed_events_in_page == events_in_page {
477 pages_of_already_processed_events += 1;
478 }
479
480 if event_pages.page.has_more {
481 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
482 {
483 log::info!(
484 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
485 );
486 break;
487 } else {
488 log::info!("Stripe events: retrieving next page");
489 event_pages = event_pages.next(&real_stripe_client).await?;
490 }
491 } else {
492 break;
493 }
494 }
495
496 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
497
498 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
499 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
500
501 for event in unprocessed_events {
502 let event_id = event.id.clone();
503 let processed_event_params = CreateProcessedStripeEventParams {
504 stripe_event_id: event.id.to_string(),
505 stripe_event_type: event_type_to_string(event.type_),
506 stripe_event_created_timestamp: event.created,
507 };
508
509 // If the event has happened too far in the past, we don't want to
510 // process it and risk overwriting other more-recent updates.
511 //
512 // 1 day was chosen arbitrarily. This could be made longer or shorter.
513 let one_day = Duration::from_secs(24 * 60 * 60);
514 let a_day_ago = Utc::now() - one_day;
515 if a_day_ago.timestamp() > event.created {
516 log::info!(
517 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
518 event_id
519 );
520 app.db
521 .create_processed_stripe_event(&processed_event_params)
522 .await?;
523
524 continue;
525 }
526
527 let process_result = match event.type_ {
528 EventType::CustomerCreated | EventType::CustomerUpdated => {
529 handle_customer_event(app, real_stripe_client, event).await
530 }
531 EventType::CustomerSubscriptionCreated
532 | EventType::CustomerSubscriptionUpdated
533 | EventType::CustomerSubscriptionPaused
534 | EventType::CustomerSubscriptionResumed
535 | EventType::CustomerSubscriptionDeleted => {
536 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
537 }
538 _ => Ok(()),
539 };
540
541 if let Some(()) = process_result
542 .with_context(|| format!("failed to process event {event_id} successfully"))
543 .log_err()
544 {
545 app.db
546 .create_processed_stripe_event(&processed_event_params)
547 .await?;
548 }
549 }
550
551 Ok(())
552}
553
554async fn handle_customer_event(
555 app: &Arc<AppState>,
556 _stripe_client: &stripe::Client,
557 event: stripe::Event,
558) -> anyhow::Result<()> {
559 let EventObject::Customer(customer) = event.data.object else {
560 bail!("unexpected event payload for {}", event.id);
561 };
562
563 log::info!("handling Stripe {} event: {}", event.type_, event.id);
564
565 let Some(email) = customer.email else {
566 log::info!("Stripe customer has no email: skipping");
567 return Ok(());
568 };
569
570 let Some(user) = app.db.get_user_by_email(&email).await? else {
571 log::info!("no user found for email: skipping");
572 return Ok(());
573 };
574
575 if let Some(existing_customer) = app
576 .db
577 .get_billing_customer_by_stripe_customer_id(&customer.id)
578 .await?
579 {
580 app.db
581 .update_billing_customer(
582 existing_customer.id,
583 &UpdateBillingCustomerParams {
584 // For now we just leave the information as-is, as it is not
585 // likely to change.
586 ..Default::default()
587 },
588 )
589 .await?;
590 } else {
591 app.db
592 .create_billing_customer(&CreateBillingCustomerParams {
593 user_id: user.id,
594 stripe_customer_id: customer.id.to_string(),
595 })
596 .await?;
597 }
598
599 Ok(())
600}
601
/// Reconciles a single Stripe subscription with our database.
///
/// Steps, in order:
/// 1. Determines the subscription kind via `StripeBilling` (or `None` if Stripe
///    billing is not configured).
/// 2. Finds or creates the billing customer for the subscription's Stripe
///    customer, erroring with "billing customer not found" if none can be resolved.
/// 3. For a Zed Pro trial that is currently trialing, records the subscription's
///    current period start as the billing customer's `trial_started_at`.
/// 4. If the subscription was canceled because payment failed, flags the billing
///    customer as having overdue invoices.
/// 5. Updates our existing record for this Stripe subscription, or creates one.
///    Before creating, if the user already has an active subscription:
///    - when that is Zed Free and the new one is a Zed Pro trial, the Zed Free
///      subscription is canceled in Stripe and creation proceeds;
///    - otherwise creation (and step 6) is skipped and the billing customer is
///      returned immediately.
/// 6. If the subscription is canceled or paused and the user no longer has an
///    active billing subscription, subscribes them to Zed Free.
///
/// Returns the billing customer the subscription belongs to.
async fn sync_subscription(
    app: &Arc<AppState>,
    stripe_client: &Arc<dyn StripeClient>,
    subscription: StripeSubscription,
) -> anyhow::Result<billing_customer::Model> {
    let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
        stripe_billing
            .determine_subscription_kind(&subscription)
            .await
    } else {
        None
    };

    let billing_customer =
        find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
            .await?
            .context("billing customer not found")?;

    // While a Zed Pro trial is running, its current period start is the trial start.
    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
        if subscription.status == SubscriptionStatus::Trialing {
            let current_period_start =
                DateTime::from_timestamp(subscription.current_period_start, 0)
                    .context("No trial subscription period start")?;

            app.db
                .update_billing_customer(
                    billing_customer.id,
                    &UpdateBillingCustomerParams {
                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
                        ..Default::default()
                    },
                )
                .await?;
        }
    }

    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
        && subscription
            .cancellation_details
            .as_ref()
            .and_then(|details| details.reason)
            .map_or(false, |reason| {
                reason == StripeCancellationDetailsReason::PaymentFailed
            });

    if was_canceled_due_to_payment_failure {
        app.db
            .update_billing_customer(
                billing_customer.id,
                &UpdateBillingCustomerParams {
                    has_overdue_invoices: ActiveValue::set(true),
                    ..Default::default()
                },
            )
            .await?;
    }

    if let Some(existing_subscription) = app
        .db
        .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
        .await?
    {
        // We already track this Stripe subscription: overwrite our copy of its state.
        app.db
            .update_billing_subscription(
                existing_subscription.id,
                &UpdateBillingSubscriptionParams {
                    billing_customer_id: ActiveValue::set(billing_customer.id),
                    kind: ActiveValue::set(subscription_kind),
                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
                    stripe_cancel_at: ActiveValue::set(
                        subscription
                            .cancel_at
                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
                            .map(|time| time.naive_utc()),
                    ),
                    stripe_cancellation_reason: ActiveValue::set(
                        subscription
                            .cancellation_details
                            .and_then(|details| details.reason)
                            .map(|reason| reason.into()),
                    ),
                    stripe_current_period_start: ActiveValue::set(Some(
                        subscription.current_period_start,
                    )),
                    stripe_current_period_end: ActiveValue::set(Some(
                        subscription.current_period_end,
                    )),
                },
            )
            .await?;
    } else {
        // This is a Stripe subscription we haven't seen before.
        if let Some(existing_subscription) = app
            .db
            .get_active_billing_subscription(billing_customer.user_id)
            .await?
        {
            // Starting a Zed Pro trial replaces the user's Zed Free subscription,
            // so cancel the Zed Free one in Stripe and go on to create the new record.
            if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
                && subscription_kind == Some(SubscriptionKind::ZedProTrial)
            {
                let stripe_subscription_id = StripeSubscriptionId(
                    existing_subscription.stripe_subscription_id.clone().into(),
                );

                stripe_client
                    .cancel_subscription(&stripe_subscription_id)
                    .await?;
            } else {
                // If the user already has an active billing subscription, ignore the
                // event and return an `Ok` to signal that it was processed
                // successfully.
                //
                // There is the possibility that this could cause us to not create a
                // subscription in the following scenario:
                //
                // 1. User has an active subscription A
                // 2. User cancels subscription A
                // 3. User creates a new subscription B
                // 4. We process the new subscription B before the cancellation of subscription A
                // 5. User ends up with no subscriptions
                //
                // In theory this situation shouldn't arise as we try to process the events in the order they occur.

                log::info!(
                    "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
                    user_id = billing_customer.user_id,
                    subscription_id = subscription.id
                );
                return Ok(billing_customer);
            }
        }

        app.db
            .create_billing_subscription(&CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: subscription_kind,
                stripe_subscription_id: subscription.id.to_string(),
                stripe_subscription_status: subscription.status.into(),
                stripe_cancellation_reason: subscription
                    .cancellation_details
                    .and_then(|details| details.reason)
                    .map(|reason| reason.into()),
                stripe_current_period_start: Some(subscription.current_period_start),
                stripe_current_period_end: Some(subscription.current_period_end),
            })
            .await?;
    }

    // When a subscription ends (canceled or paused) and the user is left without
    // any active billing subscription, fall them back to Zed Free.
    if let Some(stripe_billing) = app.stripe_billing.as_ref() {
        if subscription.status == SubscriptionStatus::Canceled
            || subscription.status == SubscriptionStatus::Paused
        {
            let already_has_active_billing_subscription = app
                .db
                .has_active_billing_subscription(billing_customer.user_id)
                .await?;
            if !already_has_active_billing_subscription {
                let stripe_customer_id =
                    StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

                stripe_billing
                    .subscribe_to_zed_free(stripe_customer_id)
                    .await?;
            }
        }
    }

    Ok(billing_customer)
}
771
772async fn handle_customer_subscription_event(
773 app: &Arc<AppState>,
774 rpc_server: &Arc<Server>,
775 stripe_client: &Arc<dyn StripeClient>,
776 event: stripe::Event,
777) -> anyhow::Result<()> {
778 let EventObject::Subscription(subscription) = event.data.object else {
779 bail!("unexpected event payload for {}", event.id);
780 };
781
782 log::info!("handling Stripe {} event: {}", event.type_, event.id);
783
784 let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
785
786 // When the user's subscription changes, push down any changes to their plan.
787 rpc_server
788 .update_plan_for_user_legacy(billing_customer.user_id)
789 .await
790 .trace_err();
791
792 // When the user's subscription changes, we want to refresh their LLM tokens
793 // to either grant/revoke access.
794 rpc_server
795 .refresh_llm_tokens_for_user(billing_customer.user_id)
796 .await;
797
798 Ok(())
799}
800
801impl From<SubscriptionStatus> for StripeSubscriptionStatus {
802 fn from(value: SubscriptionStatus) -> Self {
803 match value {
804 SubscriptionStatus::Incomplete => Self::Incomplete,
805 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
806 SubscriptionStatus::Trialing => Self::Trialing,
807 SubscriptionStatus::Active => Self::Active,
808 SubscriptionStatus::PastDue => Self::PastDue,
809 SubscriptionStatus::Canceled => Self::Canceled,
810 SubscriptionStatus::Unpaid => Self::Unpaid,
811 SubscriptionStatus::Paused => Self::Paused,
812 }
813 }
814}
815
816impl From<CancellationDetailsReason> for StripeCancellationReason {
817 fn from(value: CancellationDetailsReason) -> Self {
818 match value {
819 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
820 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
821 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
822 }
823 }
824}
825
826/// Finds or creates a billing customer using the provided customer.
827pub async fn find_or_create_billing_customer(
828 app: &Arc<AppState>,
829 stripe_client: &dyn StripeClient,
830 customer_id: &StripeCustomerId,
831) -> anyhow::Result<Option<billing_customer::Model>> {
832 // If we already have a billing customer record associated with the Stripe customer,
833 // there's nothing more we need to do.
834 if let Some(billing_customer) = app
835 .db
836 .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
837 .await?
838 {
839 return Ok(Some(billing_customer));
840 }
841
842 let customer = stripe_client.get_customer(customer_id).await?;
843
844 let Some(email) = customer.email else {
845 return Ok(None);
846 };
847
848 let Some(user) = app.db.get_user_by_email(&email).await? else {
849 return Ok(None);
850 };
851
852 let billing_customer = app
853 .db
854 .create_billing_customer(&CreateBillingCustomerParams {
855 user_id: user.id,
856 stripe_customer_id: customer.id.to_string(),
857 })
858 .await?;
859
860 Ok(Some(billing_customer))
861}
862
/// How long we wait between consecutive syncs of LLM request usage to Stripe.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_millis(60_000);
864
865pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
866 let Some(stripe_billing) = app.stripe_billing.clone() else {
867 log::warn!("failed to retrieve Stripe billing object");
868 return;
869 };
870 let Some(llm_db) = app.llm_db.clone() else {
871 log::warn!("failed to retrieve LLM database");
872 return;
873 };
874
875 let executor = app.executor.clone();
876 executor.spawn_detached({
877 let executor = executor.clone();
878 async move {
879 loop {
880 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
881 .await
882 .context("failed to sync LLM request usage to Stripe")
883 .trace_err();
884 executor
885 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
886 .await;
887 }
888 }
889 });
890}
891
/// Reports each Zed Pro subscriber's per-model LLM request usage to Stripe.
///
/// This covers every active Zed Pro billing subscription except those of staff
/// users. For each supported (model, completion mode) combination it:
/// - subscribes the Stripe subscription to the matching per-request price, if the
///   user's current usage meter for that model/mode shows any requests; and
/// - bills the request count (including `0`) under the matching meter event name.
///
/// Failures for an individual user are logged and do not stop the sync for the
/// remaining users. Errors while fetching shared data (staff users, usage meters,
/// subscriptions, or prices) abort the whole sync.
async fn sync_model_request_usage_with_stripe(
    app: &Arc<AppState>,
    llm_db: &Arc<LlmDatabase>,
    stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
    log::info!("Stripe usage sync: Starting");
    let started_at = Utc::now();

    // Staff usage is never billed.
    let staff_users = app.db.get_staff_users().await?;
    let staff_user_ids = staff_users
        .iter()
        .map(|user| user.id)
        .collect::<HashSet<UserId>>();

    // Group the current usage meters by the user they belong to.
    let usage_meters = llm_db
        .get_current_subscription_usage_meters(Utc::now())
        .await?;
    let mut usage_meters_by_user_id =
        HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
    for (usage_meter, usage) in usage_meters {
        let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
        meters.push(usage_meter);
    }

    log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
    let get_zed_pro_subscriptions_started_at = Utc::now();
    let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
    log::info!(
        "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
        billing_subscriptions.len(),
        Utc::now() - get_zed_pro_subscriptions_started_at
    );

    // Per-request prices, looked up once per sync and shared by all users.
    let claude_sonnet_4 = stripe_billing
        .find_price_by_lookup_key("claude-sonnet-4-requests")
        .await?;
    let claude_sonnet_4_max = stripe_billing
        .find_price_by_lookup_key("claude-sonnet-4-requests-max")
        .await?;
    let claude_opus_4 = stripe_billing
        .find_price_by_lookup_key("claude-opus-4-requests")
        .await?;
    let claude_opus_4_max = stripe_billing
        .find_price_by_lookup_key("claude-opus-4-requests-max")
        .await?;
    let claude_3_5_sonnet = stripe_billing
        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
        .await?;
    let claude_3_7_sonnet = stripe_billing
        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
        .await?;
    let claude_3_7_sonnet_max = stripe_billing
        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
        .await?;

    // Every (model, mode) pair that is synced for each user.
    let model_mode_combinations = [
        ("claude-opus-4", CompletionMode::Max),
        ("claude-opus-4", CompletionMode::Normal),
        ("claude-sonnet-4", CompletionMode::Max),
        ("claude-sonnet-4", CompletionMode::Normal),
        ("claude-3-7-sonnet", CompletionMode::Max),
        ("claude-3-7-sonnet", CompletionMode::Normal),
        ("claude-3-5-sonnet", CompletionMode::Normal),
    ];

    let billing_subscription_count = billing_subscriptions.len();

    log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");

    for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
        // Each user is synced in its own fallible scope so that one failure is
        // logged without aborting the sync for everyone else.
        maybe!(async {
            if staff_user_ids.contains(&user_id) {
                return anyhow::Ok(());
            }

            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
            let stripe_subscription_id =
                StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());

            let usage_meters = usage_meters_by_user_id.get(&user_id);

            for (model, mode) in &model_mode_combinations {
                // Note: the `let ... else` binding is not in scope inside `else`, so
                // `{model}` there still refers to the model name string.
                let Ok(model) =
                    llm_db.model(LanguageModelProvider::Anthropic, model)
                else {
                    log::warn!("Failed to load model for user {user_id}: {model}");
                    continue;
                };

                // Pick the price and meter event name from the model's stored name.
                // An unexpected name aborts the sync for this user.
                let (price, meter_event_name) = match model.name.as_str() {
                    "claude-opus-4" => match mode {
                        CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
                        CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
                    },
                    "claude-sonnet-4" => match mode {
                        CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
                        CompletionMode::Max => {
                            (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
                        }
                    },
                    "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
                    "claude-3-7-sonnet" => match mode {
                        CompletionMode::Normal => {
                            (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
                        }
                        CompletionMode::Max => {
                            (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
                        }
                    },
                    model_name => {
                        bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
                    }
                };

                // Missing meters (no usage recorded for this model/mode) count as 0.
                let model_requests = usage_meters
                    .and_then(|usage_meters| {
                        usage_meters
                            .iter()
                            .find(|meter| meter.model_id == model.id && meter.mode == *mode)
                    })
                    .map(|usage_meter| usage_meter.requests)
                    .unwrap_or(0);

                // Only attach the per-request price to the subscription once the
                // user has actually used this model/mode.
                if model_requests > 0 {
                    stripe_billing
                        .subscribe_to_price(&stripe_subscription_id, price)
                        .await?;
                }

                stripe_billing
                    .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
                    .await
                    .with_context(|| {
                        format!(
                            "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
                        )
                    })?;
            }

            Ok(())
        })
        .await
        .log_err();
    }

    log::info!(
        "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
        Utc::now() - started_at
    );

    Ok(())
}