1use anyhow::{Context as _, bail};
2use axum::routing::put;
3use axum::{Extension, Json, Router, extract, routing::post};
4use chrono::{DateTime, SecondsFormat, Utc};
5use collections::{HashMap, HashSet};
6use reqwest::StatusCode;
7use sea_orm::ActiveValue;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use std::{str::FromStr, sync::Arc, time::Duration};
11use stripe::{
12 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
13 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
14 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
15 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
16 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
17 CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
18 PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
19};
20use util::{ResultExt, maybe};
21use zed_llm_client::LanguageModelProvider;
22
23use crate::api::events::SnowflakeRow;
24use crate::db::billing_subscription::{
25 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
26};
27use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
28use crate::rpc::{ResultExt as _, Server};
29use crate::stripe_client::{
30 StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
31 StripeSubscriptionId, UpdateCustomerParams,
32};
33use crate::{AppState, Error, Result};
34use crate::{db::UserId, llm::db::LlmDatabase};
35use crate::{
36 db::{
37 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
38 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
39 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
40 },
41 stripe_billing::StripeBilling,
42};
43
44pub fn router() -> Router {
45 Router::new()
46 .route("/billing/preferences", put(update_billing_preferences))
47 .route("/billing/subscriptions", post(create_billing_subscription))
48 .route(
49 "/billing/subscriptions/manage",
50 post(manage_billing_subscription),
51 )
52 .route(
53 "/billing/subscriptions/sync",
54 post(sync_billing_subscription),
55 )
56}
57
/// Response body returned by `PUT /billing/preferences`.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, formatted as RFC 3339 with millisecond
    /// precision in UTC. `None` when the user has no billing customer or no trial
    /// start has been recorded.
    trial_started_at: Option<String>,
    /// The persisted monthly LLM usage spending limit, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled.
    model_request_overages_enabled: bool,
    /// The persisted spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}

/// Request body for `PUT /billing/preferences`.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID of the user whose preferences are being updated.
    github_user_id: i32,
    /// Defaults to `0` when omitted; the handler clamps negative values to `0`.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Defaults to `0` when omitted; the handler clamps negative values to `0`.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
76
77async fn update_billing_preferences(
78 Extension(app): Extension<Arc<AppState>>,
79 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
80 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
81) -> Result<Json<BillingPreferencesResponse>> {
82 let user = app
83 .db
84 .get_user_by_github_user_id(body.github_user_id)
85 .await?
86 .context("user not found")?;
87
88 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
89
90 let max_monthly_llm_usage_spending_in_cents =
91 body.max_monthly_llm_usage_spending_in_cents.max(0);
92 let model_request_overages_spend_limit_in_cents =
93 body.model_request_overages_spend_limit_in_cents.max(0);
94
95 let billing_preferences =
96 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
97 app.db
98 .update_billing_preferences(
99 user.id,
100 &UpdateBillingPreferencesParams {
101 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
102 max_monthly_llm_usage_spending_in_cents,
103 ),
104 model_request_overages_enabled: ActiveValue::set(
105 body.model_request_overages_enabled,
106 ),
107 model_request_overages_spend_limit_in_cents: ActiveValue::set(
108 model_request_overages_spend_limit_in_cents,
109 ),
110 },
111 )
112 .await?
113 } else {
114 app.db
115 .create_billing_preferences(
116 user.id,
117 &crate::db::CreateBillingPreferencesParams {
118 max_monthly_llm_usage_spending_in_cents,
119 model_request_overages_enabled: body.model_request_overages_enabled,
120 model_request_overages_spend_limit_in_cents,
121 },
122 )
123 .await?
124 };
125
126 SnowflakeRow::new(
127 "Billing Preferences Updated",
128 Some(user.metrics_id),
129 user.admin,
130 None,
131 json!({
132 "user_id": user.id,
133 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
134 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
135 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
136 }),
137 )
138 .write(&app.kinesis_client, &app.config.kinesis_stream)
139 .await
140 .log_err();
141
142 rpc_server.refresh_llm_tokens_for_user(user.id).await;
143
144 Ok(Json(BillingPreferencesResponse {
145 trial_started_at: billing_customer
146 .and_then(|billing_customer| billing_customer.trial_started_at)
147 .map(|trial_started_at| {
148 trial_started_at
149 .and_utc()
150 .to_rfc3339_opts(SecondsFormat::Millis, true)
151 }),
152 max_monthly_llm_usage_spending_in_cents: billing_preferences
153 .max_monthly_llm_usage_spending_in_cents,
154 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
155 model_request_overages_spend_limit_in_cents: billing_preferences
156 .model_request_overages_spend_limit_in_cents,
157 }))
158}
159
/// The product a user can check out via `POST /billing/subscriptions`.
///
/// Deserialized from snake_case strings: `"zed_pro"` or `"zed_pro_trial"`.
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// A paid Zed Pro subscription.
    ZedPro,
    /// A Zed Pro trial. The checkout handler rejects it when the billing customer
    /// already has a trial start recorded.
    ZedProTrial,
}

/// Request body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID of the user starting the checkout.
    github_user_id: i32,
    /// The product to check out.
    product: ProductCode,
}

/// Response body for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session to send the user to.
    checkout_session_url: String,
}
177
178/// Initiates a Stripe Checkout session for creating a billing subscription.
179async fn create_billing_subscription(
180 Extension(app): Extension<Arc<AppState>>,
181 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
182) -> Result<Json<CreateBillingSubscriptionResponse>> {
183 let user = app
184 .db
185 .get_user_by_github_user_id(body.github_user_id)
186 .await?
187 .context("user not found")?;
188
189 let Some(stripe_billing) = app.stripe_billing.clone() else {
190 log::error!("failed to retrieve Stripe billing object");
191 Err(Error::http(
192 StatusCode::NOT_IMPLEMENTED,
193 "not supported".into(),
194 ))?
195 };
196
197 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
198 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
199 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
200
201 if !is_checkout_allowed {
202 return Err(Error::http(
203 StatusCode::CONFLICT,
204 "user already has an active subscription".into(),
205 ));
206 }
207 }
208
209 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
210 if let Some(existing_billing_customer) = &existing_billing_customer {
211 if existing_billing_customer.has_overdue_invoices {
212 return Err(Error::http(
213 StatusCode::PAYMENT_REQUIRED,
214 "user has overdue invoices".into(),
215 ));
216 }
217 }
218
219 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
220 let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
221 if let Some(email) = user.email_address.as_deref() {
222 stripe_billing
223 .client()
224 .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
225 .await
226 // Update of email address is best-effort - continue checkout even if it fails
227 .context("error updating stripe customer email address")
228 .log_err();
229 }
230 customer_id
231 } else {
232 stripe_billing
233 .find_or_create_customer_by_email(user.email_address.as_deref())
234 .await?
235 };
236
237 let success_url = format!(
238 "{}/account?checkout_complete=1",
239 app.config.zed_dot_dev_url()
240 );
241
242 let checkout_session_url = match body.product {
243 ProductCode::ZedPro => {
244 stripe_billing
245 .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
246 .await?
247 }
248 ProductCode::ZedProTrial => {
249 if let Some(existing_billing_customer) = &existing_billing_customer {
250 if existing_billing_customer.trial_started_at.is_some() {
251 return Err(Error::http(
252 StatusCode::FORBIDDEN,
253 "user already used free trial".into(),
254 ));
255 }
256 }
257
258 let feature_flags = app.db.get_user_flags(user.id).await?;
259
260 stripe_billing
261 .checkout_with_zed_pro_trial(
262 &customer_id,
263 &user.github_login,
264 feature_flags,
265 &success_url,
266 )
267 .await?
268 }
269 };
270
271 Ok(Json(CreateBillingSubscriptionResponse {
272 checkout_session_url,
273 }))
274}
275
/// What the user wants to do when calling `POST /billing/subscriptions/manage`.
///
/// Deserialized from snake_case strings (e.g. `"update_payment_method"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    ///
    /// Opens the billing portal's payment-method update flow.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    ///
    /// Ends an ongoing Zed Pro trial directly, or otherwise opens the billing portal
    /// to confirm switching from the Zed Free price to the Zed Pro price.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    ///
    /// Opens the billing portal's cancellation flow; rejected for Zed Free.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    ///
    /// Applied directly in Stripe; no billing portal session is created.
    StopCancellation,
}

/// Request body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID of the user managing their subscription.
    github_user_id: i32,
    /// What the user wants to do.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    ///
    /// This is the ID of our billing subscription record, not the Stripe
    /// subscription ID.
    subscription_id: BillingSubscriptionId,
    /// Path on zed.dev to return to after updating the payment method (only used
    /// with `UpdatePaymentMethod`). Defaults to `/account` when absent.
    redirect_to: Option<String>,
}

/// Response body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session, or `None` when the request was
    /// fulfilled directly without the portal.
    billing_portal_session_url: Option<String>,
}
306
307/// Initiates a Stripe customer portal session for managing a billing subscription.
308async fn manage_billing_subscription(
309 Extension(app): Extension<Arc<AppState>>,
310 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
311) -> Result<Json<ManageBillingSubscriptionResponse>> {
312 let user = app
313 .db
314 .get_user_by_github_user_id(body.github_user_id)
315 .await?
316 .context("user not found")?;
317
318 let Some(stripe_client) = app.real_stripe_client.clone() else {
319 log::error!("failed to retrieve Stripe client");
320 Err(Error::http(
321 StatusCode::NOT_IMPLEMENTED,
322 "not supported".into(),
323 ))?
324 };
325
326 let Some(stripe_billing) = app.stripe_billing.clone() else {
327 log::error!("failed to retrieve Stripe billing object");
328 Err(Error::http(
329 StatusCode::NOT_IMPLEMENTED,
330 "not supported".into(),
331 ))?
332 };
333
334 let customer = app
335 .db
336 .get_billing_customer_by_user_id(user.id)
337 .await?
338 .context("billing customer not found")?;
339 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
340 .context("failed to parse customer ID")?;
341
342 let subscription = app
343 .db
344 .get_billing_subscription_by_id(body.subscription_id)
345 .await?
346 .context("subscription not found")?;
347 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
348 .context("failed to parse subscription ID")?;
349
350 if body.intent == ManageSubscriptionIntent::StopCancellation {
351 let updated_stripe_subscription = Subscription::update(
352 &stripe_client,
353 &subscription_id,
354 stripe::UpdateSubscription {
355 cancel_at_period_end: Some(false),
356 ..Default::default()
357 },
358 )
359 .await?;
360
361 app.db
362 .update_billing_subscription(
363 subscription.id,
364 &UpdateBillingSubscriptionParams {
365 stripe_cancel_at: ActiveValue::set(
366 updated_stripe_subscription
367 .cancel_at
368 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
369 .map(|time| time.naive_utc()),
370 ),
371 ..Default::default()
372 },
373 )
374 .await?;
375
376 return Ok(Json(ManageBillingSubscriptionResponse {
377 billing_portal_session_url: None,
378 }));
379 }
380
381 let flow = match body.intent {
382 ManageSubscriptionIntent::ManageSubscription => None,
383 ManageSubscriptionIntent::UpgradeToPro => {
384 let zed_pro_price_id: stripe::PriceId =
385 stripe_billing.zed_pro_price_id().await?.try_into()?;
386 let zed_free_price_id: stripe::PriceId =
387 stripe_billing.zed_free_price_id().await?.try_into()?;
388
389 let stripe_subscription =
390 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
391
392 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
393 && stripe_subscription.items.data.iter().any(|item| {
394 item.price
395 .as_ref()
396 .map_or(false, |price| price.id == zed_pro_price_id)
397 });
398 if is_on_zed_pro_trial {
399 let payment_methods = PaymentMethod::list(
400 &stripe_client,
401 &stripe::ListPaymentMethods {
402 customer: Some(stripe_subscription.customer.id()),
403 ..Default::default()
404 },
405 )
406 .await?;
407
408 let has_payment_method = !payment_methods.data.is_empty();
409 if !has_payment_method {
410 return Err(Error::http(
411 StatusCode::BAD_REQUEST,
412 "missing payment method".into(),
413 ));
414 }
415
416 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
417 Subscription::update(
418 &stripe_client,
419 &stripe_subscription.id,
420 stripe::UpdateSubscription {
421 trial_end: Some(stripe::Scheduled::now()),
422 ..Default::default()
423 },
424 )
425 .await?;
426
427 return Ok(Json(ManageBillingSubscriptionResponse {
428 billing_portal_session_url: None,
429 }));
430 }
431
432 let subscription_item_to_update = stripe_subscription
433 .items
434 .data
435 .iter()
436 .find_map(|item| {
437 let price = item.price.as_ref()?;
438
439 if price.id == zed_free_price_id {
440 Some(item.id.clone())
441 } else {
442 None
443 }
444 })
445 .context("No subscription item to update")?;
446
447 Some(CreateBillingPortalSessionFlowData {
448 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
449 subscription_update_confirm: Some(
450 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
451 subscription: subscription.stripe_subscription_id,
452 items: vec![
453 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
454 id: subscription_item_to_update.to_string(),
455 price: Some(zed_pro_price_id.to_string()),
456 quantity: Some(1),
457 },
458 ],
459 discounts: None,
460 },
461 ),
462 ..Default::default()
463 })
464 }
465 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
466 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
467 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
468 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
469 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
470 return_url: format!(
471 "{}{path}",
472 app.config.zed_dot_dev_url(),
473 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
474 ),
475 }),
476 ..Default::default()
477 }),
478 ..Default::default()
479 }),
480 ManageSubscriptionIntent::Cancel => {
481 if subscription.kind == Some(SubscriptionKind::ZedFree) {
482 return Err(Error::http(
483 StatusCode::BAD_REQUEST,
484 "free subscription cannot be canceled".into(),
485 ));
486 }
487
488 Some(CreateBillingPortalSessionFlowData {
489 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
490 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
491 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
492 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
493 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
494 }),
495 ..Default::default()
496 }),
497 subscription_cancel: Some(
498 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
499 subscription: subscription.stripe_subscription_id,
500 retention: None,
501 },
502 ),
503 ..Default::default()
504 })
505 }
506 ManageSubscriptionIntent::StopCancellation => unreachable!(),
507 };
508
509 let mut params = CreateBillingPortalSession::new(customer_id);
510 params.flow_data = flow;
511 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
512 params.return_url = Some(&return_url);
513
514 let session = BillingPortalSession::create(&stripe_client, params).await?;
515
516 Ok(Json(ManageBillingSubscriptionResponse {
517 billing_portal_session_url: Some(session.url),
518 }))
519}
520
/// Request body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID of the user whose Stripe subscriptions should be re-synced.
    github_user_id: i32,
}

/// Response body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID of the synced user.
    stripe_customer_id: String,
}
530
531async fn sync_billing_subscription(
532 Extension(app): Extension<Arc<AppState>>,
533 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
534) -> Result<Json<SyncBillingSubscriptionResponse>> {
535 let Some(stripe_client) = app.stripe_client.clone() else {
536 log::error!("failed to retrieve Stripe client");
537 Err(Error::http(
538 StatusCode::NOT_IMPLEMENTED,
539 "not supported".into(),
540 ))?
541 };
542
543 let user = app
544 .db
545 .get_user_by_github_user_id(body.github_user_id)
546 .await?
547 .context("user not found")?;
548
549 let billing_customer = app
550 .db
551 .get_billing_customer_by_user_id(user.id)
552 .await?
553 .context("billing customer not found")?;
554 let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
555
556 let subscriptions = stripe_client
557 .list_subscriptions_for_customer(&stripe_customer_id)
558 .await?;
559
560 for subscription in subscriptions {
561 let subscription_id = subscription.id.clone();
562
563 sync_subscription(&app, &stripe_client, subscription)
564 .await
565 .with_context(|| {
566 format!(
567 "failed to sync subscription {subscription_id} for user {}",
568 user.id,
569 )
570 })?;
571 }
572
573 Ok(Json(SyncBillingSubscriptionResponse {
574 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
575 }))
576}
577
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — <https://blog.sequinstream.com/events-not-webhooks/>
///
/// We currently poll every 5 seconds, ten times less often than that.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);

/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// From Stripe's documentation for listing events:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
///
/// The pages need not be consecutive: the counter used during a poll is only ever
/// incremented, never reset.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
604
605/// Polls the Stripe events API periodically to reconcile the records in our
606/// database with the data in Stripe.
607pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
608 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
609 log::warn!("failed to retrieve Stripe client");
610 return;
611 };
612 let Some(stripe_client) = app.stripe_client.clone() else {
613 log::warn!("failed to retrieve Stripe client");
614 return;
615 };
616
617 let executor = app.executor.clone();
618 executor.spawn_detached({
619 let executor = executor.clone();
620 async move {
621 loop {
622 poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
623 .await
624 .log_err();
625
626 executor.sleep(POLL_EVENTS_INTERVAL).await;
627 }
628 }
629 });
630}
631
632async fn poll_stripe_events(
633 app: &Arc<AppState>,
634 rpc_server: &Arc<Server>,
635 stripe_client: &Arc<dyn StripeClient>,
636 real_stripe_client: &stripe::Client,
637) -> anyhow::Result<()> {
638 fn event_type_to_string(event_type: EventType) -> String {
639 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
640 // so we need to unquote it.
641 event_type.to_string().trim_matches('"').to_string()
642 }
643
644 let event_types = [
645 EventType::CustomerCreated,
646 EventType::CustomerUpdated,
647 EventType::CustomerSubscriptionCreated,
648 EventType::CustomerSubscriptionUpdated,
649 EventType::CustomerSubscriptionPaused,
650 EventType::CustomerSubscriptionResumed,
651 EventType::CustomerSubscriptionDeleted,
652 ]
653 .into_iter()
654 .map(event_type_to_string)
655 .collect::<Vec<_>>();
656
657 let mut pages_of_already_processed_events = 0;
658 let mut unprocessed_events = Vec::new();
659
660 log::info!(
661 "Stripe events: starting retrieval for {}",
662 event_types.join(", ")
663 );
664 let mut params = ListEvents::new();
665 params.types = Some(event_types.clone());
666 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
667
668 let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
669 .await?
670 .paginate(params);
671
672 loop {
673 let processed_event_ids = {
674 let event_ids = event_pages
675 .page
676 .data
677 .iter()
678 .map(|event| event.id.as_str())
679 .collect::<Vec<_>>();
680 app.db
681 .get_processed_stripe_events_by_event_ids(&event_ids)
682 .await?
683 .into_iter()
684 .map(|event| event.stripe_event_id)
685 .collect::<Vec<_>>()
686 };
687
688 let mut processed_events_in_page = 0;
689 let events_in_page = event_pages.page.data.len();
690 for event in &event_pages.page.data {
691 if processed_event_ids.contains(&event.id.to_string()) {
692 processed_events_in_page += 1;
693 log::debug!("Stripe events: already processed '{}', skipping", event.id);
694 } else {
695 unprocessed_events.push(event.clone());
696 }
697 }
698
699 if processed_events_in_page == events_in_page {
700 pages_of_already_processed_events += 1;
701 }
702
703 if event_pages.page.has_more {
704 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
705 {
706 log::info!(
707 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
708 );
709 break;
710 } else {
711 log::info!("Stripe events: retrieving next page");
712 event_pages = event_pages.next(&real_stripe_client).await?;
713 }
714 } else {
715 break;
716 }
717 }
718
719 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
720
721 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
722 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
723
724 for event in unprocessed_events {
725 let event_id = event.id.clone();
726 let processed_event_params = CreateProcessedStripeEventParams {
727 stripe_event_id: event.id.to_string(),
728 stripe_event_type: event_type_to_string(event.type_),
729 stripe_event_created_timestamp: event.created,
730 };
731
732 // If the event has happened too far in the past, we don't want to
733 // process it and risk overwriting other more-recent updates.
734 //
735 // 1 day was chosen arbitrarily. This could be made longer or shorter.
736 let one_day = Duration::from_secs(24 * 60 * 60);
737 let a_day_ago = Utc::now() - one_day;
738 if a_day_ago.timestamp() > event.created {
739 log::info!(
740 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
741 event_id
742 );
743 app.db
744 .create_processed_stripe_event(&processed_event_params)
745 .await?;
746
747 continue;
748 }
749
750 let process_result = match event.type_ {
751 EventType::CustomerCreated | EventType::CustomerUpdated => {
752 handle_customer_event(app, real_stripe_client, event).await
753 }
754 EventType::CustomerSubscriptionCreated
755 | EventType::CustomerSubscriptionUpdated
756 | EventType::CustomerSubscriptionPaused
757 | EventType::CustomerSubscriptionResumed
758 | EventType::CustomerSubscriptionDeleted => {
759 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
760 }
761 _ => Ok(()),
762 };
763
764 if let Some(()) = process_result
765 .with_context(|| format!("failed to process event {event_id} successfully"))
766 .log_err()
767 {
768 app.db
769 .create_processed_stripe_event(&processed_event_params)
770 .await?;
771 }
772 }
773
774 Ok(())
775}
776
777async fn handle_customer_event(
778 app: &Arc<AppState>,
779 _stripe_client: &stripe::Client,
780 event: stripe::Event,
781) -> anyhow::Result<()> {
782 let EventObject::Customer(customer) = event.data.object else {
783 bail!("unexpected event payload for {}", event.id);
784 };
785
786 log::info!("handling Stripe {} event: {}", event.type_, event.id);
787
788 let Some(email) = customer.email else {
789 log::info!("Stripe customer has no email: skipping");
790 return Ok(());
791 };
792
793 let Some(user) = app.db.get_user_by_email(&email).await? else {
794 log::info!("no user found for email: skipping");
795 return Ok(());
796 };
797
798 if let Some(existing_customer) = app
799 .db
800 .get_billing_customer_by_stripe_customer_id(&customer.id)
801 .await?
802 {
803 app.db
804 .update_billing_customer(
805 existing_customer.id,
806 &UpdateBillingCustomerParams {
807 // For now we just leave the information as-is, as it is not
808 // likely to change.
809 ..Default::default()
810 },
811 )
812 .await?;
813 } else {
814 app.db
815 .create_billing_customer(&CreateBillingCustomerParams {
816 user_id: user.id,
817 stripe_customer_id: customer.id.to_string(),
818 })
819 .await?;
820 }
821
822 Ok(())
823}
824
/// Syncs a single Stripe subscription into our database and returns the billing
/// customer it belongs to.
///
/// In order, this:
/// 1. Determines the subscription's [`SubscriptionKind`]. The kind is `None` when
///    Stripe billing is not configured or when `determine_subscription_kind`
///    yields `None`.
/// 2. Finds or creates the billing customer for the subscription's Stripe customer,
///    failing if none is found.
/// 3. If this is a Zed Pro trial that is currently trialing, records the trial start
///    (the subscription's current period start).
/// 4. If the subscription was canceled because a payment failed, flags the customer
///    as having overdue invoices.
/// 5. Updates our existing record for this Stripe subscription, or creates a new
///    one. See the comments below for when creation is skipped.
/// 6. If this subscription is canceled or paused and the user has no other active
///    subscription, subscribes the customer to Zed Free. This step only runs when
///    Stripe billing is configured.
///
/// The returned model is the one fetched in step 2, so it does not reflect the
/// updates made in steps 3 and 4.
async fn sync_subscription(
    app: &Arc<AppState>,
    stripe_client: &Arc<dyn StripeClient>,
    subscription: StripeSubscription,
) -> anyhow::Result<billing_customer::Model> {
    // Without a configured `StripeBilling` we can't classify the subscription.
    let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
        stripe_billing
            .determine_subscription_kind(&subscription)
            .await
    } else {
        None
    };

    let billing_customer =
        find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
            .await?
            .context("billing customer not found")?;

    // Record when the user's Zed Pro trial started; the checkout endpoint refuses a
    // second trial once this is set.
    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
        if subscription.status == SubscriptionStatus::Trialing {
            let current_period_start =
                DateTime::from_timestamp(subscription.current_period_start, 0)
                    .context("No trial subscription period start")?;

            app.db
                .update_billing_customer(
                    billing_customer.id,
                    &UpdateBillingCustomerParams {
                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
                        ..Default::default()
                    },
                )
                .await?;
        }
    }

    // A cancellation caused by a failed payment leaves the customer with overdue
    // invoices; the checkout endpoint rejects new checkouts while this flag is set.
    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
        && subscription
            .cancellation_details
            .as_ref()
            .and_then(|details| details.reason)
            .map_or(false, |reason| {
                reason == StripeCancellationDetailsReason::PaymentFailed
            });

    if was_canceled_due_to_payment_failure {
        app.db
            .update_billing_customer(
                billing_customer.id,
                &UpdateBillingCustomerParams {
                    has_overdue_invoices: ActiveValue::set(true),
                    ..Default::default()
                },
            )
            .await?;
    }

    if let Some(existing_subscription) = app
        .db
        .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
        .await?
    {
        // We already track this Stripe subscription: overwrite our record with its
        // latest state.
        app.db
            .update_billing_subscription(
                existing_subscription.id,
                &UpdateBillingSubscriptionParams {
                    billing_customer_id: ActiveValue::set(billing_customer.id),
                    kind: ActiveValue::set(subscription_kind),
                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
                    stripe_cancel_at: ActiveValue::set(
                        subscription
                            .cancel_at
                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
                            .map(|time| time.naive_utc()),
                    ),
                    stripe_cancellation_reason: ActiveValue::set(
                        subscription
                            .cancellation_details
                            .and_then(|details| details.reason)
                            .map(|reason| reason.into()),
                    ),
                    stripe_current_period_start: ActiveValue::set(Some(
                        subscription.current_period_start,
                    )),
                    stripe_current_period_end: ActiveValue::set(Some(
                        subscription.current_period_end,
                    )),
                },
            )
            .await?;
    } else {
        if let Some(existing_subscription) = app
            .db
            .get_active_billing_subscription(billing_customer.user_id)
            .await?
        {
            if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
                && subscription_kind == Some(SubscriptionKind::ZedProTrial)
            {
                // A new Zed Pro trial replaces Zed Free: cancel the free subscription
                // in Stripe, then fall through to create the trial's record below.
                let stripe_subscription_id = StripeSubscriptionId(
                    existing_subscription.stripe_subscription_id.clone().into(),
                );

                stripe_client
                    .cancel_subscription(&stripe_subscription_id)
                    .await?;
            } else {
                // If the user already has an active billing subscription, ignore the
                // event and return an `Ok` to signal that it was processed
                // successfully.
                //
                // There is the possibility that this could cause us to not create a
                // subscription in the following scenario:
                //
                // 1. User has an active subscription A
                // 2. User cancels subscription A
                // 3. User creates a new subscription B
                // 4. We process the new subscription B before the cancellation of subscription A
                // 5. User ends up with no subscriptions
                //
                // In theory this situation shouldn't arise as we try to process the events in the order they occur.

                log::info!(
                    "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
                    user_id = billing_customer.user_id,
                    subscription_id = subscription.id
                );
                return Ok(billing_customer);
            }
        }

        app.db
            .create_billing_subscription(&CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: subscription_kind,
                stripe_subscription_id: subscription.id.to_string(),
                stripe_subscription_status: subscription.status.into(),
                stripe_cancellation_reason: subscription
                    .cancellation_details
                    .and_then(|details| details.reason)
                    .map(|reason| reason.into()),
                stripe_current_period_start: Some(subscription.current_period_start),
                stripe_current_period_end: Some(subscription.current_period_end),
            })
            .await?;
    }

    // When a subscription stops being active, fall the user back to Zed Free unless
    // they still have another active subscription.
    if let Some(stripe_billing) = app.stripe_billing.as_ref() {
        if subscription.status == SubscriptionStatus::Canceled
            || subscription.status == SubscriptionStatus::Paused
        {
            let already_has_active_billing_subscription = app
                .db
                .has_active_billing_subscription(billing_customer.user_id)
                .await?;
            if !already_has_active_billing_subscription {
                let stripe_customer_id =
                    StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

                stripe_billing
                    .subscribe_to_zed_free(stripe_customer_id)
                    .await?;
            }
        }
    }

    Ok(billing_customer)
}
994
995async fn handle_customer_subscription_event(
996 app: &Arc<AppState>,
997 rpc_server: &Arc<Server>,
998 stripe_client: &Arc<dyn StripeClient>,
999 event: stripe::Event,
1000) -> anyhow::Result<()> {
1001 let EventObject::Subscription(subscription) = event.data.object else {
1002 bail!("unexpected event payload for {}", event.id);
1003 };
1004
1005 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1006
1007 let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
1008
1009 // When the user's subscription changes, push down any changes to their plan.
1010 rpc_server
1011 .update_plan_for_user(billing_customer.user_id)
1012 .await
1013 .trace_err();
1014
1015 // When the user's subscription changes, we want to refresh their LLM tokens
1016 // to either grant/revoke access.
1017 rpc_server
1018 .refresh_llm_tokens_for_user(billing_customer.user_id)
1019 .await;
1020
1021 Ok(())
1022}
1023
1024impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1025 fn from(value: SubscriptionStatus) -> Self {
1026 match value {
1027 SubscriptionStatus::Incomplete => Self::Incomplete,
1028 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1029 SubscriptionStatus::Trialing => Self::Trialing,
1030 SubscriptionStatus::Active => Self::Active,
1031 SubscriptionStatus::PastDue => Self::PastDue,
1032 SubscriptionStatus::Canceled => Self::Canceled,
1033 SubscriptionStatus::Unpaid => Self::Unpaid,
1034 SubscriptionStatus::Paused => Self::Paused,
1035 }
1036 }
1037}
1038
1039impl From<CancellationDetailsReason> for StripeCancellationReason {
1040 fn from(value: CancellationDetailsReason) -> Self {
1041 match value {
1042 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1043 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1044 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1045 }
1046 }
1047}
1048
1049/// Finds or creates a billing customer using the provided customer.
1050pub async fn find_or_create_billing_customer(
1051 app: &Arc<AppState>,
1052 stripe_client: &dyn StripeClient,
1053 customer_id: &StripeCustomerId,
1054) -> anyhow::Result<Option<billing_customer::Model>> {
1055 // If we already have a billing customer record associated with the Stripe customer,
1056 // there's nothing more we need to do.
1057 if let Some(billing_customer) = app
1058 .db
1059 .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
1060 .await?
1061 {
1062 return Ok(Some(billing_customer));
1063 }
1064
1065 let customer = stripe_client.get_customer(customer_id).await?;
1066
1067 let Some(email) = customer.email else {
1068 return Ok(None);
1069 };
1070
1071 let Some(user) = app.db.get_user_by_email(&email).await? else {
1072 return Ok(None);
1073 };
1074
1075 let billing_customer = app
1076 .db
1077 .create_billing_customer(&CreateBillingCustomerParams {
1078 user_id: user.id,
1079 stripe_customer_id: customer.id.to_string(),
1080 })
1081 .await?;
1082
1083 Ok(Some(billing_customer))
1084}
1085
1086const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1087
1088pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1089 let Some(stripe_billing) = app.stripe_billing.clone() else {
1090 log::warn!("failed to retrieve Stripe billing object");
1091 return;
1092 };
1093 let Some(llm_db) = app.llm_db.clone() else {
1094 log::warn!("failed to retrieve LLM database");
1095 return;
1096 };
1097
1098 let executor = app.executor.clone();
1099 executor.spawn_detached({
1100 let executor = executor.clone();
1101 async move {
1102 loop {
1103 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1104 .await
1105 .context("failed to sync LLM request usage to Stripe")
1106 .trace_err();
1107 executor
1108 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1109 .await;
1110 }
1111 }
1112 });
1113}
1114
1115async fn sync_model_request_usage_with_stripe(
1116 app: &Arc<AppState>,
1117 llm_db: &Arc<LlmDatabase>,
1118 stripe_billing: &Arc<StripeBilling>,
1119) -> anyhow::Result<()> {
1120 log::info!("Stripe usage sync: Starting");
1121 let started_at = Utc::now();
1122
1123 let staff_users = app.db.get_staff_users().await?;
1124 let staff_user_ids = staff_users
1125 .iter()
1126 .map(|user| user.id)
1127 .collect::<HashSet<UserId>>();
1128
1129 let usage_meters = llm_db
1130 .get_current_subscription_usage_meters(Utc::now())
1131 .await?;
1132 let mut usage_meters_by_user_id =
1133 HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
1134 for (usage_meter, usage) in usage_meters {
1135 let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
1136 meters.push(usage_meter);
1137 }
1138
1139 log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
1140 let get_zed_pro_subscriptions_started_at = Utc::now();
1141 let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
1142 log::info!(
1143 "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
1144 billing_subscriptions.len(),
1145 Utc::now() - get_zed_pro_subscriptions_started_at
1146 );
1147
1148 let claude_sonnet_4 = stripe_billing
1149 .find_price_by_lookup_key("claude-sonnet-4-requests")
1150 .await?;
1151 let claude_sonnet_4_max = stripe_billing
1152 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1153 .await?;
1154 let claude_opus_4 = stripe_billing
1155 .find_price_by_lookup_key("claude-opus-4-requests")
1156 .await?;
1157 let claude_opus_4_max = stripe_billing
1158 .find_price_by_lookup_key("claude-opus-4-requests-max")
1159 .await?;
1160 let claude_3_5_sonnet = stripe_billing
1161 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1162 .await?;
1163 let claude_3_7_sonnet = stripe_billing
1164 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1165 .await?;
1166 let claude_3_7_sonnet_max = stripe_billing
1167 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1168 .await?;
1169
1170 let model_mode_combinations = [
1171 ("claude-opus-4", CompletionMode::Max),
1172 ("claude-opus-4", CompletionMode::Normal),
1173 ("claude-sonnet-4", CompletionMode::Max),
1174 ("claude-sonnet-4", CompletionMode::Normal),
1175 ("claude-3-7-sonnet", CompletionMode::Max),
1176 ("claude-3-7-sonnet", CompletionMode::Normal),
1177 ("claude-3-5-sonnet", CompletionMode::Normal),
1178 ];
1179
1180 let billing_subscription_count = billing_subscriptions.len();
1181
1182 log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
1183
1184 for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
1185 maybe!(async {
1186 if staff_user_ids.contains(&user_id) {
1187 return anyhow::Ok(());
1188 }
1189
1190 let stripe_customer_id =
1191 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1192 let stripe_subscription_id =
1193 StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1194
1195 let usage_meters = usage_meters_by_user_id.get(&user_id);
1196
1197 for (model, mode) in &model_mode_combinations {
1198 let Ok(model) =
1199 llm_db.model(LanguageModelProvider::Anthropic, model)
1200 else {
1201 log::warn!("Failed to load model for user {user_id}: {model}");
1202 continue;
1203 };
1204
1205 let (price, meter_event_name) = match model.name.as_str() {
1206 "claude-opus-4" => match mode {
1207 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1208 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1209 },
1210 "claude-sonnet-4" => match mode {
1211 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1212 CompletionMode::Max => {
1213 (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
1214 }
1215 },
1216 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1217 "claude-3-7-sonnet" => match mode {
1218 CompletionMode::Normal => {
1219 (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
1220 }
1221 CompletionMode::Max => {
1222 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1223 }
1224 },
1225 model_name => {
1226 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1227 }
1228 };
1229
1230 let model_requests = usage_meters
1231 .and_then(|usage_meters| {
1232 usage_meters
1233 .iter()
1234 .find(|meter| meter.model_id == model.id && meter.mode == *mode)
1235 })
1236 .map(|usage_meter| usage_meter.requests)
1237 .unwrap_or(0);
1238
1239 if model_requests > 0 {
1240 stripe_billing
1241 .subscribe_to_price(&stripe_subscription_id, price)
1242 .await?;
1243 }
1244
1245 stripe_billing
1246 .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
1247 .await
1248 .with_context(|| {
1249 format!(
1250 "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
1251 )
1252 })?;
1253 }
1254
1255 Ok(())
1256 })
1257 .await
1258 .log_err();
1259 }
1260
1261 log::info!(
1262 "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
1263 Utc::now() - started_at
1264 );
1265
1266 Ok(())
1267}