1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
21 Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
31use crate::rpc::{ResultExt as _, Server};
32use crate::{AppState, Error, Result};
33use crate::{db::UserId, llm::db::LlmDatabase};
34use crate::{
35 db::{
36 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
37 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
38 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
39 },
40 stripe_billing::StripeBilling,
41};
42
43pub fn router() -> Router {
44 Router::new()
45 .route(
46 "/billing/preferences",
47 get(get_billing_preferences).put(update_billing_preferences),
48 )
49 .route(
50 "/billing/subscriptions",
51 get(list_billing_subscriptions).post(create_billing_subscription),
52 )
53 .route(
54 "/billing/subscriptions/manage",
55 post(manage_billing_subscription),
56 )
57 .route(
58 "/billing/subscriptions/migrate",
59 post(migrate_to_new_billing),
60 )
61 .route(
62 "/billing/subscriptions/sync",
63 post(sync_billing_subscription),
64 )
65 .route("/billing/usage", get(get_current_usage))
66}
67
68#[derive(Debug, Deserialize)]
69struct GetBillingPreferencesParams {
70 github_user_id: i32,
71}
72
73#[derive(Debug, Serialize)]
74struct BillingPreferencesResponse {
75 trial_started_at: Option<String>,
76 max_monthly_llm_usage_spending_in_cents: i32,
77 model_request_overages_enabled: bool,
78 model_request_overages_spend_limit_in_cents: i32,
79}
80
81async fn get_billing_preferences(
82 Extension(app): Extension<Arc<AppState>>,
83 Query(params): Query<GetBillingPreferencesParams>,
84) -> Result<Json<BillingPreferencesResponse>> {
85 let user = app
86 .db
87 .get_user_by_github_user_id(params.github_user_id)
88 .await?
89 .context("user not found")?;
90
91 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
92 let preferences = app.db.get_billing_preferences(user.id).await?;
93
94 Ok(Json(BillingPreferencesResponse {
95 trial_started_at: billing_customer
96 .and_then(|billing_customer| billing_customer.trial_started_at)
97 .map(|trial_started_at| {
98 trial_started_at
99 .and_utc()
100 .to_rfc3339_opts(SecondsFormat::Millis, true)
101 }),
102 max_monthly_llm_usage_spending_in_cents: preferences
103 .as_ref()
104 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
105 preferences.max_monthly_llm_usage_spending_in_cents
106 }),
107 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
108 preferences.model_request_overages_enabled
109 }),
110 model_request_overages_spend_limit_in_cents: preferences
111 .as_ref()
112 .map_or(0, |preferences| {
113 preferences.model_request_overages_spend_limit_in_cents
114 }),
115 }))
116}
117
118#[derive(Debug, Deserialize)]
119struct UpdateBillingPreferencesBody {
120 github_user_id: i32,
121 #[serde(default)]
122 max_monthly_llm_usage_spending_in_cents: i32,
123 #[serde(default)]
124 model_request_overages_enabled: bool,
125 #[serde(default)]
126 model_request_overages_spend_limit_in_cents: i32,
127}
128
129async fn update_billing_preferences(
130 Extension(app): Extension<Arc<AppState>>,
131 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
132 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
133) -> Result<Json<BillingPreferencesResponse>> {
134 let user = app
135 .db
136 .get_user_by_github_user_id(body.github_user_id)
137 .await?
138 .context("user not found")?;
139
140 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
141
142 let max_monthly_llm_usage_spending_in_cents =
143 body.max_monthly_llm_usage_spending_in_cents.max(0);
144 let model_request_overages_spend_limit_in_cents =
145 body.model_request_overages_spend_limit_in_cents.max(0);
146
147 let billing_preferences =
148 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
149 app.db
150 .update_billing_preferences(
151 user.id,
152 &UpdateBillingPreferencesParams {
153 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
154 max_monthly_llm_usage_spending_in_cents,
155 ),
156 model_request_overages_enabled: ActiveValue::set(
157 body.model_request_overages_enabled,
158 ),
159 model_request_overages_spend_limit_in_cents: ActiveValue::set(
160 model_request_overages_spend_limit_in_cents,
161 ),
162 },
163 )
164 .await?
165 } else {
166 app.db
167 .create_billing_preferences(
168 user.id,
169 &crate::db::CreateBillingPreferencesParams {
170 max_monthly_llm_usage_spending_in_cents,
171 model_request_overages_enabled: body.model_request_overages_enabled,
172 model_request_overages_spend_limit_in_cents,
173 },
174 )
175 .await?
176 };
177
178 SnowflakeRow::new(
179 "Billing Preferences Updated",
180 Some(user.metrics_id),
181 user.admin,
182 None,
183 json!({
184 "user_id": user.id,
185 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
186 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
187 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
188 }),
189 )
190 .write(&app.kinesis_client, &app.config.kinesis_stream)
191 .await
192 .log_err();
193
194 rpc_server.refresh_llm_tokens_for_user(user.id).await;
195
196 Ok(Json(BillingPreferencesResponse {
197 trial_started_at: billing_customer
198 .and_then(|billing_customer| billing_customer.trial_started_at)
199 .map(|trial_started_at| {
200 trial_started_at
201 .and_utc()
202 .to_rfc3339_opts(SecondsFormat::Millis, true)
203 }),
204 max_monthly_llm_usage_spending_in_cents: billing_preferences
205 .max_monthly_llm_usage_spending_in_cents,
206 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
207 model_request_overages_spend_limit_in_cents: billing_preferences
208 .model_request_overages_spend_limit_in_cents,
209 }))
210}
211
212#[derive(Debug, Deserialize)]
213struct ListBillingSubscriptionsParams {
214 github_user_id: i32,
215}
216
217#[derive(Debug, Serialize)]
218struct BillingSubscriptionJson {
219 id: BillingSubscriptionId,
220 name: String,
221 status: StripeSubscriptionStatus,
222 trial_end_at: Option<String>,
223 cancel_at: Option<String>,
224 /// Whether this subscription can be canceled.
225 is_cancelable: bool,
226}
227
228#[derive(Debug, Serialize)]
229struct ListBillingSubscriptionsResponse {
230 subscriptions: Vec<BillingSubscriptionJson>,
231}
232
233async fn list_billing_subscriptions(
234 Extension(app): Extension<Arc<AppState>>,
235 Query(params): Query<ListBillingSubscriptionsParams>,
236) -> Result<Json<ListBillingSubscriptionsResponse>> {
237 let user = app
238 .db
239 .get_user_by_github_user_id(params.github_user_id)
240 .await?
241 .context("user not found")?;
242
243 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
244
245 Ok(Json(ListBillingSubscriptionsResponse {
246 subscriptions: subscriptions
247 .into_iter()
248 .map(|subscription| BillingSubscriptionJson {
249 id: subscription.id,
250 name: match subscription.kind {
251 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
252 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
253 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
254 None => "Zed LLM Usage".to_string(),
255 },
256 status: subscription.stripe_subscription_status,
257 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
258 maybe!({
259 let end_at = subscription.stripe_current_period_end?;
260 let end_at = DateTime::from_timestamp(end_at, 0)?;
261
262 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
263 })
264 } else {
265 None
266 },
267 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
268 cancel_at
269 .and_utc()
270 .to_rfc3339_opts(SecondsFormat::Millis, true)
271 }),
272 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
273 && subscription.stripe_subscription_status.is_cancelable()
274 && subscription.stripe_cancel_at.is_none(),
275 })
276 .collect(),
277 }))
278}
279
280#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
281#[serde(rename_all = "snake_case")]
282enum ProductCode {
283 ZedPro,
284 ZedProTrial,
285 ZedFree,
286}
287
288#[derive(Debug, Deserialize)]
289struct CreateBillingSubscriptionBody {
290 github_user_id: i32,
291 product: ProductCode,
292}
293
294#[derive(Debug, Serialize)]
295struct CreateBillingSubscriptionResponse {
296 checkout_session_url: String,
297}
298
299/// Initiates a Stripe Checkout session for creating a billing subscription.
300async fn create_billing_subscription(
301 Extension(app): Extension<Arc<AppState>>,
302 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
303) -> Result<Json<CreateBillingSubscriptionResponse>> {
304 let user = app
305 .db
306 .get_user_by_github_user_id(body.github_user_id)
307 .await?
308 .context("user not found")?;
309
310 let Some(stripe_billing) = app.stripe_billing.clone() else {
311 log::error!("failed to retrieve Stripe billing object");
312 Err(Error::http(
313 StatusCode::NOT_IMPLEMENTED,
314 "not supported".into(),
315 ))?
316 };
317
318 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
319 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
320 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
321
322 if !is_checkout_allowed {
323 return Err(Error::http(
324 StatusCode::CONFLICT,
325 "user already has an active subscription".into(),
326 ));
327 }
328 }
329
330 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
331 if let Some(existing_billing_customer) = &existing_billing_customer {
332 if existing_billing_customer.has_overdue_invoices {
333 return Err(Error::http(
334 StatusCode::PAYMENT_REQUIRED,
335 "user has overdue invoices".into(),
336 ));
337 }
338 }
339
340 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
341 CustomerId::from_str(&existing_customer.stripe_customer_id)
342 .context("failed to parse customer ID")?
343 } else {
344 stripe_billing
345 .find_or_create_customer_by_email(user.email_address.as_deref())
346 .await?
347 .try_into()?
348 };
349
350 let success_url = format!(
351 "{}/account?checkout_complete=1",
352 app.config.zed_dot_dev_url()
353 );
354
355 let checkout_session_url = match body.product {
356 ProductCode::ZedPro => {
357 stripe_billing
358 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
359 .await?
360 }
361 ProductCode::ZedProTrial => {
362 if let Some(existing_billing_customer) = &existing_billing_customer {
363 if existing_billing_customer.trial_started_at.is_some() {
364 return Err(Error::http(
365 StatusCode::FORBIDDEN,
366 "user already used free trial".into(),
367 ));
368 }
369 }
370
371 let feature_flags = app.db.get_user_flags(user.id).await?;
372
373 stripe_billing
374 .checkout_with_zed_pro_trial(
375 customer_id,
376 &user.github_login,
377 feature_flags,
378 &success_url,
379 )
380 .await?
381 }
382 ProductCode::ZedFree => {
383 stripe_billing
384 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
385 .await?
386 }
387 };
388
389 Ok(Json(CreateBillingSubscriptionResponse {
390 checkout_session_url,
391 }))
392}
393
394#[derive(Debug, PartialEq, Deserialize)]
395#[serde(rename_all = "snake_case")]
396enum ManageSubscriptionIntent {
397 /// The user intends to manage their subscription.
398 ///
399 /// This will open the Stripe billing portal without putting the user in a specific flow.
400 ManageSubscription,
401 /// The user intends to update their payment method.
402 UpdatePaymentMethod,
403 /// The user intends to upgrade to Zed Pro.
404 UpgradeToPro,
405 /// The user intends to cancel their subscription.
406 Cancel,
407 /// The user intends to stop the cancellation of their subscription.
408 StopCancellation,
409}
410
411#[derive(Debug, Deserialize)]
412struct ManageBillingSubscriptionBody {
413 github_user_id: i32,
414 intent: ManageSubscriptionIntent,
415 /// The ID of the subscription to manage.
416 subscription_id: BillingSubscriptionId,
417 redirect_to: Option<String>,
418}
419
420#[derive(Debug, Serialize)]
421struct ManageBillingSubscriptionResponse {
422 billing_portal_session_url: Option<String>,
423}
424
425/// Initiates a Stripe customer portal session for managing a billing subscription.
426async fn manage_billing_subscription(
427 Extension(app): Extension<Arc<AppState>>,
428 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
429) -> Result<Json<ManageBillingSubscriptionResponse>> {
430 let user = app
431 .db
432 .get_user_by_github_user_id(body.github_user_id)
433 .await?
434 .context("user not found")?;
435
436 let Some(stripe_client) = app.stripe_client.clone() else {
437 log::error!("failed to retrieve Stripe client");
438 Err(Error::http(
439 StatusCode::NOT_IMPLEMENTED,
440 "not supported".into(),
441 ))?
442 };
443
444 let Some(stripe_billing) = app.stripe_billing.clone() else {
445 log::error!("failed to retrieve Stripe billing object");
446 Err(Error::http(
447 StatusCode::NOT_IMPLEMENTED,
448 "not supported".into(),
449 ))?
450 };
451
452 let customer = app
453 .db
454 .get_billing_customer_by_user_id(user.id)
455 .await?
456 .context("billing customer not found")?;
457 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
458 .context("failed to parse customer ID")?;
459
460 let subscription = app
461 .db
462 .get_billing_subscription_by_id(body.subscription_id)
463 .await?
464 .context("subscription not found")?;
465 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
466 .context("failed to parse subscription ID")?;
467
468 if body.intent == ManageSubscriptionIntent::StopCancellation {
469 let updated_stripe_subscription = Subscription::update(
470 &stripe_client,
471 &subscription_id,
472 stripe::UpdateSubscription {
473 cancel_at_period_end: Some(false),
474 ..Default::default()
475 },
476 )
477 .await?;
478
479 app.db
480 .update_billing_subscription(
481 subscription.id,
482 &UpdateBillingSubscriptionParams {
483 stripe_cancel_at: ActiveValue::set(
484 updated_stripe_subscription
485 .cancel_at
486 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
487 .map(|time| time.naive_utc()),
488 ),
489 ..Default::default()
490 },
491 )
492 .await?;
493
494 return Ok(Json(ManageBillingSubscriptionResponse {
495 billing_portal_session_url: None,
496 }));
497 }
498
499 let flow = match body.intent {
500 ManageSubscriptionIntent::ManageSubscription => None,
501 ManageSubscriptionIntent::UpgradeToPro => {
502 let zed_pro_price_id: stripe::PriceId =
503 stripe_billing.zed_pro_price_id().await?.try_into()?;
504 let zed_free_price_id: stripe::PriceId =
505 stripe_billing.zed_free_price_id().await?.try_into()?;
506
507 let stripe_subscription =
508 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
509
510 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
511 && stripe_subscription.items.data.iter().any(|item| {
512 item.price
513 .as_ref()
514 .map_or(false, |price| price.id == zed_pro_price_id)
515 });
516 if is_on_zed_pro_trial {
517 let payment_methods = PaymentMethod::list(
518 &stripe_client,
519 &stripe::ListPaymentMethods {
520 customer: Some(stripe_subscription.customer.id()),
521 ..Default::default()
522 },
523 )
524 .await?;
525
526 let has_payment_method = !payment_methods.data.is_empty();
527 if !has_payment_method {
528 return Err(Error::http(
529 StatusCode::BAD_REQUEST,
530 "missing payment method".into(),
531 ));
532 }
533
534 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
535 Subscription::update(
536 &stripe_client,
537 &stripe_subscription.id,
538 stripe::UpdateSubscription {
539 trial_end: Some(stripe::Scheduled::now()),
540 ..Default::default()
541 },
542 )
543 .await?;
544
545 return Ok(Json(ManageBillingSubscriptionResponse {
546 billing_portal_session_url: None,
547 }));
548 }
549
550 let subscription_item_to_update = stripe_subscription
551 .items
552 .data
553 .iter()
554 .find_map(|item| {
555 let price = item.price.as_ref()?;
556
557 if price.id == zed_free_price_id {
558 Some(item.id.clone())
559 } else {
560 None
561 }
562 })
563 .context("No subscription item to update")?;
564
565 Some(CreateBillingPortalSessionFlowData {
566 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
567 subscription_update_confirm: Some(
568 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
569 subscription: subscription.stripe_subscription_id,
570 items: vec![
571 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
572 id: subscription_item_to_update.to_string(),
573 price: Some(zed_pro_price_id.to_string()),
574 quantity: Some(1),
575 },
576 ],
577 discounts: None,
578 },
579 ),
580 ..Default::default()
581 })
582 }
583 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
584 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
585 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
586 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
587 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
588 return_url: format!(
589 "{}{path}",
590 app.config.zed_dot_dev_url(),
591 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
592 ),
593 }),
594 ..Default::default()
595 }),
596 ..Default::default()
597 }),
598 ManageSubscriptionIntent::Cancel => {
599 if subscription.kind == Some(SubscriptionKind::ZedFree) {
600 return Err(Error::http(
601 StatusCode::BAD_REQUEST,
602 "free subscription cannot be canceled".into(),
603 ));
604 }
605
606 Some(CreateBillingPortalSessionFlowData {
607 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
608 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
609 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
610 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
611 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
612 }),
613 ..Default::default()
614 }),
615 subscription_cancel: Some(
616 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
617 subscription: subscription.stripe_subscription_id,
618 retention: None,
619 },
620 ),
621 ..Default::default()
622 })
623 }
624 ManageSubscriptionIntent::StopCancellation => unreachable!(),
625 };
626
627 let mut params = CreateBillingPortalSession::new(customer_id);
628 params.flow_data = flow;
629 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
630 params.return_url = Some(&return_url);
631
632 let session = BillingPortalSession::create(&stripe_client, params).await?;
633
634 Ok(Json(ManageBillingSubscriptionResponse {
635 billing_portal_session_url: Some(session.url),
636 }))
637}
638
639#[derive(Debug, Deserialize)]
640struct MigrateToNewBillingBody {
641 github_user_id: i32,
642}
643
644#[derive(Debug, Serialize)]
645struct MigrateToNewBillingResponse {
646 /// The ID of the subscription that was canceled.
647 canceled_subscription_id: Option<String>,
648}
649
650async fn migrate_to_new_billing(
651 Extension(app): Extension<Arc<AppState>>,
652 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
653) -> Result<Json<MigrateToNewBillingResponse>> {
654 let Some(stripe_client) = app.stripe_client.clone() else {
655 log::error!("failed to retrieve Stripe client");
656 Err(Error::http(
657 StatusCode::NOT_IMPLEMENTED,
658 "not supported".into(),
659 ))?
660 };
661
662 let user = app
663 .db
664 .get_user_by_github_user_id(body.github_user_id)
665 .await?
666 .context("user not found")?;
667
668 let old_billing_subscriptions_by_user = app
669 .db
670 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
671 .await?;
672
673 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
674 old_billing_subscriptions_by_user.get(&user.id)
675 {
676 let stripe_subscription_id = billing_subscription
677 .stripe_subscription_id
678 .parse::<stripe::SubscriptionId>()
679 .context("failed to parse Stripe subscription ID from database")?;
680
681 Subscription::cancel(
682 &stripe_client,
683 &stripe_subscription_id,
684 stripe::CancelSubscription {
685 invoice_now: Some(true),
686 ..Default::default()
687 },
688 )
689 .await?;
690
691 Some(stripe_subscription_id)
692 } else {
693 None
694 };
695
696 let all_feature_flags = app.db.list_feature_flags().await?;
697 let user_feature_flags = app.db.get_user_flags(user.id).await?;
698
699 for feature_flag in ["new-billing", "assistant2"] {
700 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
701 if already_in_feature_flag {
702 continue;
703 }
704
705 let feature_flag = all_feature_flags
706 .iter()
707 .find(|flag| flag.flag == feature_flag)
708 .context("failed to find feature flag: {feature_flag:?}")?;
709
710 app.db.add_user_flag(user.id, feature_flag.id).await?;
711 }
712
713 Ok(Json(MigrateToNewBillingResponse {
714 canceled_subscription_id: canceled_subscription_id
715 .map(|subscription_id| subscription_id.to_string()),
716 }))
717}
718
719#[derive(Debug, Deserialize)]
720struct SyncBillingSubscriptionBody {
721 github_user_id: i32,
722}
723
724#[derive(Debug, Serialize)]
725struct SyncBillingSubscriptionResponse {
726 stripe_customer_id: String,
727}
728
729async fn sync_billing_subscription(
730 Extension(app): Extension<Arc<AppState>>,
731 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
732) -> Result<Json<SyncBillingSubscriptionResponse>> {
733 let Some(stripe_client) = app.stripe_client.clone() else {
734 log::error!("failed to retrieve Stripe client");
735 Err(Error::http(
736 StatusCode::NOT_IMPLEMENTED,
737 "not supported".into(),
738 ))?
739 };
740
741 let user = app
742 .db
743 .get_user_by_github_user_id(body.github_user_id)
744 .await?
745 .context("user not found")?;
746
747 let billing_customer = app
748 .db
749 .get_billing_customer_by_user_id(user.id)
750 .await?
751 .context("billing customer not found")?;
752 let stripe_customer_id = billing_customer
753 .stripe_customer_id
754 .parse::<stripe::CustomerId>()
755 .context("failed to parse Stripe customer ID from database")?;
756
757 let subscriptions = Subscription::list(
758 &stripe_client,
759 &stripe::ListSubscriptions {
760 customer: Some(stripe_customer_id),
761 // Sync all non-canceled subscriptions.
762 status: None,
763 ..Default::default()
764 },
765 )
766 .await?;
767
768 for subscription in subscriptions.data {
769 let subscription_id = subscription.id.clone();
770
771 sync_subscription(&app, &stripe_client, subscription)
772 .await
773 .with_context(|| {
774 format!(
775 "failed to sync subscription {subscription_id} for user {}",
776 user.id,
777 )
778 })?;
779 }
780
781 Ok(Json(SyncBillingSubscriptionResponse {
782 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
783 }))
784}
785
/// Delay between consecutive polls of the Stripe events API.
///
/// The value trades off two concerns:
/// 1. It should be short enough that changes made in Stripe show up quickly.
/// 2. It should be long enough that polling doesn't eat into our rate limits.
///
/// For comparison, the Sequin folks poll much more aggressively, at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// Page size requested when listing Stripe events.
///
/// 100 is the largest page Stripe allows, which minimizes the number of requests we make.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we accept before we stop
/// paging through the Stripe events API.
///
/// This keeps us from repeatedly fetching events we have already seen and handled.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
812
813/// Polls the Stripe events API periodically to reconcile the records in our
814/// database with the data in Stripe.
815pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
816 let Some(stripe_client) = app.stripe_client.clone() else {
817 log::warn!("failed to retrieve Stripe client");
818 return;
819 };
820
821 let executor = app.executor.clone();
822 executor.spawn_detached({
823 let executor = executor.clone();
824 async move {
825 loop {
826 poll_stripe_events(&app, &rpc_server, &stripe_client)
827 .await
828 .log_err();
829
830 executor.sleep(POLL_EVENTS_INTERVAL).await;
831 }
832 }
833 });
834}
835
836async fn poll_stripe_events(
837 app: &Arc<AppState>,
838 rpc_server: &Arc<Server>,
839 stripe_client: &stripe::Client,
840) -> anyhow::Result<()> {
841 fn event_type_to_string(event_type: EventType) -> String {
842 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
843 // so we need to unquote it.
844 event_type.to_string().trim_matches('"').to_string()
845 }
846
847 let event_types = [
848 EventType::CustomerCreated,
849 EventType::CustomerUpdated,
850 EventType::CustomerSubscriptionCreated,
851 EventType::CustomerSubscriptionUpdated,
852 EventType::CustomerSubscriptionPaused,
853 EventType::CustomerSubscriptionResumed,
854 EventType::CustomerSubscriptionDeleted,
855 ]
856 .into_iter()
857 .map(event_type_to_string)
858 .collect::<Vec<_>>();
859
860 let mut pages_of_already_processed_events = 0;
861 let mut unprocessed_events = Vec::new();
862
863 log::info!(
864 "Stripe events: starting retrieval for {}",
865 event_types.join(", ")
866 );
867 let mut params = ListEvents::new();
868 params.types = Some(event_types.clone());
869 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
870
871 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
872 .await?
873 .paginate(params);
874
875 loop {
876 let processed_event_ids = {
877 let event_ids = event_pages
878 .page
879 .data
880 .iter()
881 .map(|event| event.id.as_str())
882 .collect::<Vec<_>>();
883 app.db
884 .get_processed_stripe_events_by_event_ids(&event_ids)
885 .await?
886 .into_iter()
887 .map(|event| event.stripe_event_id)
888 .collect::<Vec<_>>()
889 };
890
891 let mut processed_events_in_page = 0;
892 let events_in_page = event_pages.page.data.len();
893 for event in &event_pages.page.data {
894 if processed_event_ids.contains(&event.id.to_string()) {
895 processed_events_in_page += 1;
896 log::debug!("Stripe events: already processed '{}', skipping", event.id);
897 } else {
898 unprocessed_events.push(event.clone());
899 }
900 }
901
902 if processed_events_in_page == events_in_page {
903 pages_of_already_processed_events += 1;
904 }
905
906 if event_pages.page.has_more {
907 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
908 {
909 log::info!(
910 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
911 );
912 break;
913 } else {
914 log::info!("Stripe events: retrieving next page");
915 event_pages = event_pages.next(&stripe_client).await?;
916 }
917 } else {
918 break;
919 }
920 }
921
922 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
923
924 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
925 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
926
927 for event in unprocessed_events {
928 let event_id = event.id.clone();
929 let processed_event_params = CreateProcessedStripeEventParams {
930 stripe_event_id: event.id.to_string(),
931 stripe_event_type: event_type_to_string(event.type_),
932 stripe_event_created_timestamp: event.created,
933 };
934
935 // If the event has happened too far in the past, we don't want to
936 // process it and risk overwriting other more-recent updates.
937 //
938 // 1 day was chosen arbitrarily. This could be made longer or shorter.
939 let one_day = Duration::from_secs(24 * 60 * 60);
940 let a_day_ago = Utc::now() - one_day;
941 if a_day_ago.timestamp() > event.created {
942 log::info!(
943 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
944 event_id
945 );
946 app.db
947 .create_processed_stripe_event(&processed_event_params)
948 .await?;
949
950 continue;
951 }
952
953 let process_result = match event.type_ {
954 EventType::CustomerCreated | EventType::CustomerUpdated => {
955 handle_customer_event(app, stripe_client, event).await
956 }
957 EventType::CustomerSubscriptionCreated
958 | EventType::CustomerSubscriptionUpdated
959 | EventType::CustomerSubscriptionPaused
960 | EventType::CustomerSubscriptionResumed
961 | EventType::CustomerSubscriptionDeleted => {
962 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
963 }
964 _ => Ok(()),
965 };
966
967 if let Some(()) = process_result
968 .with_context(|| format!("failed to process event {event_id} successfully"))
969 .log_err()
970 {
971 app.db
972 .create_processed_stripe_event(&processed_event_params)
973 .await?;
974 }
975 }
976
977 Ok(())
978}
979
980async fn handle_customer_event(
981 app: &Arc<AppState>,
982 _stripe_client: &stripe::Client,
983 event: stripe::Event,
984) -> anyhow::Result<()> {
985 let EventObject::Customer(customer) = event.data.object else {
986 bail!("unexpected event payload for {}", event.id);
987 };
988
989 log::info!("handling Stripe {} event: {}", event.type_, event.id);
990
991 let Some(email) = customer.email else {
992 log::info!("Stripe customer has no email: skipping");
993 return Ok(());
994 };
995
996 let Some(user) = app.db.get_user_by_email(&email).await? else {
997 log::info!("no user found for email: skipping");
998 return Ok(());
999 };
1000
1001 if let Some(existing_customer) = app
1002 .db
1003 .get_billing_customer_by_stripe_customer_id(&customer.id)
1004 .await?
1005 {
1006 app.db
1007 .update_billing_customer(
1008 existing_customer.id,
1009 &UpdateBillingCustomerParams {
1010 // For now we just leave the information as-is, as it is not
1011 // likely to change.
1012 ..Default::default()
1013 },
1014 )
1015 .await?;
1016 } else {
1017 app.db
1018 .create_billing_customer(&CreateBillingCustomerParams {
1019 user_id: user.id,
1020 stripe_customer_id: customer.id.to_string(),
1021 })
1022 .await?;
1023 }
1024
1025 Ok(())
1026}
1027
1028async fn sync_subscription(
1029 app: &Arc<AppState>,
1030 stripe_client: &stripe::Client,
1031 subscription: stripe::Subscription,
1032) -> anyhow::Result<billing_customer::Model> {
1033 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1034 stripe_billing
1035 .determine_subscription_kind(&subscription)
1036 .await
1037 } else {
1038 None
1039 };
1040
1041 let billing_customer =
1042 find_or_create_billing_customer(app, stripe_client, subscription.customer)
1043 .await?
1044 .context("billing customer not found")?;
1045
1046 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1047 if subscription.status == SubscriptionStatus::Trialing {
1048 let current_period_start =
1049 DateTime::from_timestamp(subscription.current_period_start, 0)
1050 .context("No trial subscription period start")?;
1051
1052 app.db
1053 .update_billing_customer(
1054 billing_customer.id,
1055 &UpdateBillingCustomerParams {
1056 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1057 ..Default::default()
1058 },
1059 )
1060 .await?;
1061 }
1062 }
1063
1064 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1065 && subscription
1066 .cancellation_details
1067 .as_ref()
1068 .and_then(|details| details.reason)
1069 .map_or(false, |reason| {
1070 reason == CancellationDetailsReason::PaymentFailed
1071 });
1072
1073 if was_canceled_due_to_payment_failure {
1074 app.db
1075 .update_billing_customer(
1076 billing_customer.id,
1077 &UpdateBillingCustomerParams {
1078 has_overdue_invoices: ActiveValue::set(true),
1079 ..Default::default()
1080 },
1081 )
1082 .await?;
1083 }
1084
1085 if let Some(existing_subscription) = app
1086 .db
1087 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1088 .await?
1089 {
1090 app.db
1091 .update_billing_subscription(
1092 existing_subscription.id,
1093 &UpdateBillingSubscriptionParams {
1094 billing_customer_id: ActiveValue::set(billing_customer.id),
1095 kind: ActiveValue::set(subscription_kind),
1096 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1097 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1098 stripe_cancel_at: ActiveValue::set(
1099 subscription
1100 .cancel_at
1101 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1102 .map(|time| time.naive_utc()),
1103 ),
1104 stripe_cancellation_reason: ActiveValue::set(
1105 subscription
1106 .cancellation_details
1107 .and_then(|details| details.reason)
1108 .map(|reason| reason.into()),
1109 ),
1110 stripe_current_period_start: ActiveValue::set(Some(
1111 subscription.current_period_start,
1112 )),
1113 stripe_current_period_end: ActiveValue::set(Some(
1114 subscription.current_period_end,
1115 )),
1116 },
1117 )
1118 .await?;
1119 } else {
1120 if let Some(existing_subscription) = app
1121 .db
1122 .get_active_billing_subscription(billing_customer.user_id)
1123 .await?
1124 {
1125 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1126 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1127 {
1128 let stripe_subscription_id = existing_subscription
1129 .stripe_subscription_id
1130 .parse::<stripe::SubscriptionId>()
1131 .context("failed to parse Stripe subscription ID from database")?;
1132
1133 Subscription::cancel(
1134 &stripe_client,
1135 &stripe_subscription_id,
1136 stripe::CancelSubscription {
1137 invoice_now: None,
1138 ..Default::default()
1139 },
1140 )
1141 .await?;
1142 } else {
1143 // If the user already has an active billing subscription, ignore the
1144 // event and return an `Ok` to signal that it was processed
1145 // successfully.
1146 //
1147 // There is the possibility that this could cause us to not create a
1148 // subscription in the following scenario:
1149 //
1150 // 1. User has an active subscription A
1151 // 2. User cancels subscription A
1152 // 3. User creates a new subscription B
1153 // 4. We process the new subscription B before the cancellation of subscription A
1154 // 5. User ends up with no subscriptions
1155 //
1156 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1157
1158 log::info!(
1159 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1160 user_id = billing_customer.user_id,
1161 subscription_id = subscription.id
1162 );
1163 return Ok(billing_customer);
1164 }
1165 }
1166
1167 app.db
1168 .create_billing_subscription(&CreateBillingSubscriptionParams {
1169 billing_customer_id: billing_customer.id,
1170 kind: subscription_kind,
1171 stripe_subscription_id: subscription.id.to_string(),
1172 stripe_subscription_status: subscription.status.into(),
1173 stripe_cancellation_reason: subscription
1174 .cancellation_details
1175 .and_then(|details| details.reason)
1176 .map(|reason| reason.into()),
1177 stripe_current_period_start: Some(subscription.current_period_start),
1178 stripe_current_period_end: Some(subscription.current_period_end),
1179 })
1180 .await?;
1181 }
1182
1183 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1184 if subscription.status == SubscriptionStatus::Canceled
1185 || subscription.status == SubscriptionStatus::Paused
1186 {
1187 let already_has_active_billing_subscription = app
1188 .db
1189 .has_active_billing_subscription(billing_customer.user_id)
1190 .await?;
1191 if !already_has_active_billing_subscription {
1192 let stripe_customer_id = billing_customer
1193 .stripe_customer_id
1194 .parse::<stripe::CustomerId>()
1195 .context("failed to parse Stripe customer ID from database")?;
1196
1197 stripe_billing
1198 .subscribe_to_zed_free(stripe_customer_id)
1199 .await?;
1200 }
1201 }
1202 }
1203
1204 Ok(billing_customer)
1205}
1206
1207async fn handle_customer_subscription_event(
1208 app: &Arc<AppState>,
1209 rpc_server: &Arc<Server>,
1210 stripe_client: &stripe::Client,
1211 event: stripe::Event,
1212) -> anyhow::Result<()> {
1213 let EventObject::Subscription(subscription) = event.data.object else {
1214 bail!("unexpected event payload for {}", event.id);
1215 };
1216
1217 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1218
1219 let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1220
1221 // When the user's subscription changes, push down any changes to their plan.
1222 rpc_server
1223 .update_plan_for_user(billing_customer.user_id)
1224 .await
1225 .trace_err();
1226
1227 // When the user's subscription changes, we want to refresh their LLM tokens
1228 // to either grant/revoke access.
1229 rpc_server
1230 .refresh_llm_tokens_for_user(billing_customer.user_id)
1231 .await;
1232
1233 Ok(())
1234}
1235
/// Query parameters accepted by the `get_current_usage` handler.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID of the user whose usage is being requested.
    github_user_id: i32,
}
1240
/// Usage of a single metered resource within the current billing period.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Amount used so far in the period.
    pub used: i32,
    /// Plan limit for the period; `None` when the plan is unlimited.
    pub limit: Option<i32>,
    /// Amount left before hitting `limit` (`None` when unlimited). When usage
    /// has been recorded, `get_current_usage` clamps this at zero.
    pub remaining: Option<i32>,
}
1247
/// Request count for one model and completion mode, taken from a usage meter.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model, as resolved from the meter's model ID.
    pub model: String,
    /// Completion mode the requests were made in.
    pub mode: CompletionMode,
    /// Number of requests recorded on the meter.
    pub requests: i32,
}
1254
/// Usage breakdown for the user's current billing period.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Totals for model requests against the plan limit.
    pub model_requests: UsageCounts,
    /// Per-model/per-mode request counts from the user's current usage meters.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Totals for edit predictions against the plan limit.
    pub edit_predictions: UsageCounts,
}
1261
/// Response body of the `get_current_usage` handler.
///
/// The `Default` value (empty `plan`, no `current_usage`) is returned when the
/// user has no active billing subscription or its current period is unknown.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// The plan's string form, as produced by `Plan::as_str`.
    pub plan: String,
    /// Usage for the current billing period, when one could be determined.
    pub current_usage: Option<CurrentUsage>,
}
1267
1268async fn get_current_usage(
1269 Extension(app): Extension<Arc<AppState>>,
1270 Query(params): Query<GetCurrentUsageParams>,
1271) -> Result<Json<GetCurrentUsageResponse>> {
1272 let user = app
1273 .db
1274 .get_user_by_github_user_id(params.github_user_id)
1275 .await?
1276 .context("user not found")?;
1277
1278 let feature_flags = app.db.get_user_flags(user.id).await?;
1279 let has_extended_trial = feature_flags
1280 .iter()
1281 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1282
1283 let Some(llm_db) = app.llm_db.clone() else {
1284 return Err(Error::http(
1285 StatusCode::NOT_IMPLEMENTED,
1286 "LLM database not available".into(),
1287 ));
1288 };
1289
1290 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1291 return Ok(Json(GetCurrentUsageResponse::default()));
1292 };
1293
1294 let subscription_period = maybe!({
1295 let period_start_at = subscription.current_period_start_at()?;
1296 let period_end_at = subscription.current_period_end_at()?;
1297
1298 Some((period_start_at, period_end_at))
1299 });
1300
1301 let Some((period_start_at, period_end_at)) = subscription_period else {
1302 return Ok(Json(GetCurrentUsageResponse::default()));
1303 };
1304
1305 let usage = llm_db
1306 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1307 .await?;
1308
1309 let plan = subscription
1310 .kind
1311 .map(Into::into)
1312 .unwrap_or(zed_llm_client::Plan::ZedFree);
1313
1314 let model_requests_limit = match plan.model_requests_limit() {
1315 zed_llm_client::UsageLimit::Limited(limit) => {
1316 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1317 1_000
1318 } else {
1319 limit
1320 };
1321
1322 Some(limit)
1323 }
1324 zed_llm_client::UsageLimit::Unlimited => None,
1325 };
1326
1327 let edit_predictions_limit = match plan.edit_predictions_limit() {
1328 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1329 zed_llm_client::UsageLimit::Unlimited => None,
1330 };
1331
1332 let Some(usage) = usage else {
1333 return Ok(Json(GetCurrentUsageResponse {
1334 plan: plan.as_str().to_string(),
1335 current_usage: Some(CurrentUsage {
1336 model_requests: UsageCounts {
1337 used: 0,
1338 limit: model_requests_limit,
1339 remaining: model_requests_limit,
1340 },
1341 model_request_usage: Vec::new(),
1342 edit_predictions: UsageCounts {
1343 used: 0,
1344 limit: edit_predictions_limit,
1345 remaining: edit_predictions_limit,
1346 },
1347 }),
1348 }));
1349 };
1350
1351 let subscription_usage_meters = llm_db
1352 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1353 .await?;
1354
1355 let model_request_usage = subscription_usage_meters
1356 .into_iter()
1357 .filter_map(|(usage_meter, _usage)| {
1358 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1359
1360 Some(ModelRequestUsage {
1361 model: model.name.clone(),
1362 mode: usage_meter.mode,
1363 requests: usage_meter.requests,
1364 })
1365 })
1366 .collect::<Vec<_>>();
1367
1368 Ok(Json(GetCurrentUsageResponse {
1369 plan: plan.as_str().to_string(),
1370 current_usage: Some(CurrentUsage {
1371 model_requests: UsageCounts {
1372 used: usage.model_requests,
1373 limit: model_requests_limit,
1374 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1375 },
1376 model_request_usage,
1377 edit_predictions: UsageCounts {
1378 used: usage.edit_predictions,
1379 limit: edit_predictions_limit,
1380 remaining: edit_predictions_limit
1381 .map(|limit| (limit - usage.edit_predictions).max(0)),
1382 },
1383 }),
1384 }))
1385}
1386
1387impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1388 fn from(value: SubscriptionStatus) -> Self {
1389 match value {
1390 SubscriptionStatus::Incomplete => Self::Incomplete,
1391 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1392 SubscriptionStatus::Trialing => Self::Trialing,
1393 SubscriptionStatus::Active => Self::Active,
1394 SubscriptionStatus::PastDue => Self::PastDue,
1395 SubscriptionStatus::Canceled => Self::Canceled,
1396 SubscriptionStatus::Unpaid => Self::Unpaid,
1397 SubscriptionStatus::Paused => Self::Paused,
1398 }
1399 }
1400}
1401
1402impl From<CancellationDetailsReason> for StripeCancellationReason {
1403 fn from(value: CancellationDetailsReason) -> Self {
1404 match value {
1405 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1406 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1407 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1408 }
1409 }
1410}
1411
1412/// Finds or creates a billing customer using the provided customer.
1413pub async fn find_or_create_billing_customer(
1414 app: &Arc<AppState>,
1415 stripe_client: &stripe::Client,
1416 customer_or_id: Expandable<Customer>,
1417) -> anyhow::Result<Option<billing_customer::Model>> {
1418 let customer_id = match &customer_or_id {
1419 Expandable::Id(id) => id,
1420 Expandable::Object(customer) => customer.id.as_ref(),
1421 };
1422
1423 // If we already have a billing customer record associated with the Stripe customer,
1424 // there's nothing more we need to do.
1425 if let Some(billing_customer) = app
1426 .db
1427 .get_billing_customer_by_stripe_customer_id(customer_id)
1428 .await?
1429 {
1430 return Ok(Some(billing_customer));
1431 }
1432
1433 // If all we have is a customer ID, resolve it to a full customer record by
1434 // hitting the Stripe API.
1435 let customer = match customer_or_id {
1436 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1437 Expandable::Object(customer) => *customer,
1438 };
1439
1440 let Some(email) = customer.email else {
1441 return Ok(None);
1442 };
1443
1444 let Some(user) = app.db.get_user_by_email(&email).await? else {
1445 return Ok(None);
1446 };
1447
1448 let billing_customer = app
1449 .db
1450 .create_billing_customer(&CreateBillingCustomerParams {
1451 user_id: user.id,
1452 stripe_customer_id: customer.id.to_string(),
1453 })
1454 .await?;
1455
1456 Ok(Some(billing_customer))
1457}
1458
/// How long the background task started by
/// `sync_llm_request_usage_with_stripe_periodically` sleeps between syncs.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1460
1461pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1462 let Some(stripe_billing) = app.stripe_billing.clone() else {
1463 log::warn!("failed to retrieve Stripe billing object");
1464 return;
1465 };
1466 let Some(llm_db) = app.llm_db.clone() else {
1467 log::warn!("failed to retrieve LLM database");
1468 return;
1469 };
1470
1471 let executor = app.executor.clone();
1472 executor.spawn_detached({
1473 let executor = executor.clone();
1474 async move {
1475 loop {
1476 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1477 .await
1478 .context("failed to sync LLM request usage to Stripe")
1479 .trace_err();
1480 executor
1481 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1482 .await;
1483 }
1484 }
1485 });
1486}
1487
1488async fn sync_model_request_usage_with_stripe(
1489 app: &Arc<AppState>,
1490 llm_db: &Arc<LlmDatabase>,
1491 stripe_billing: &Arc<StripeBilling>,
1492) -> anyhow::Result<()> {
1493 let staff_users = app.db.get_staff_users().await?;
1494 let staff_user_ids = staff_users
1495 .iter()
1496 .map(|user| user.id)
1497 .collect::<HashSet<UserId>>();
1498
1499 let usage_meters = llm_db
1500 .get_current_subscription_usage_meters(Utc::now())
1501 .await?;
1502 let usage_meters = usage_meters
1503 .into_iter()
1504 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1505 .collect::<Vec<_>>();
1506 let user_ids = usage_meters
1507 .iter()
1508 .map(|(_, usage)| usage.user_id)
1509 .collect::<HashSet<UserId>>();
1510 let billing_subscriptions = app
1511 .db
1512 .get_active_zed_pro_billing_subscriptions(user_ids)
1513 .await?;
1514
1515 let claude_sonnet_4 = stripe_billing
1516 .find_price_by_lookup_key("claude-sonnet-4-requests")
1517 .await?;
1518 let claude_sonnet_4_max = stripe_billing
1519 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1520 .await?;
1521 let claude_opus_4 = stripe_billing
1522 .find_price_by_lookup_key("claude-opus-4-requests")
1523 .await?;
1524 let claude_opus_4_max = stripe_billing
1525 .find_price_by_lookup_key("claude-opus-4-requests-max")
1526 .await?;
1527 let claude_3_5_sonnet = stripe_billing
1528 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1529 .await?;
1530 let claude_3_7_sonnet = stripe_billing
1531 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1532 .await?;
1533 let claude_3_7_sonnet_max = stripe_billing
1534 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1535 .await?;
1536
1537 for (usage_meter, usage) in usage_meters {
1538 maybe!(async {
1539 let Some((billing_customer, billing_subscription)) =
1540 billing_subscriptions.get(&usage.user_id)
1541 else {
1542 bail!(
1543 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1544 usage.user_id
1545 );
1546 };
1547
1548 let stripe_customer_id = billing_customer
1549 .stripe_customer_id
1550 .parse::<stripe::CustomerId>()
1551 .context("failed to parse Stripe customer ID from database")?;
1552 let stripe_subscription_id = billing_subscription
1553 .stripe_subscription_id
1554 .parse::<stripe::SubscriptionId>()
1555 .context("failed to parse Stripe subscription ID from database")?;
1556
1557 let model = llm_db.model_by_id(usage_meter.model_id)?;
1558
1559 let (price, meter_event_name) = match model.name.as_str() {
1560 "claude-opus-4" => match usage_meter.mode {
1561 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1562 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1563 },
1564 "claude-sonnet-4" => match usage_meter.mode {
1565 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1566 CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
1567 },
1568 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1569 "claude-3-7-sonnet" => match usage_meter.mode {
1570 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1571 CompletionMode::Max => {
1572 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1573 }
1574 },
1575 model_name => {
1576 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1577 }
1578 };
1579
1580 stripe_billing
1581 .subscribe_to_price(&stripe_subscription_id.into(), price)
1582 .await?;
1583 stripe_billing
1584 .bill_model_request_usage(
1585 &stripe_customer_id,
1586 meter_event_name,
1587 usage_meter.requests,
1588 )
1589 .await?;
1590
1591 Ok(())
1592 })
1593 .await
1594 .log_err();
1595 }
1596
1597 Ok(())
1598}