1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
21 Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
31use crate::rpc::{ResultExt as _, Server};
32use crate::{AppState, Error, Result};
33use crate::{db::UserId, llm::db::LlmDatabase};
34use crate::{
35 db::{
36 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
37 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
38 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
39 },
40 stripe_billing::StripeBilling,
41};
42
43pub fn router() -> Router {
44 Router::new()
45 .route(
46 "/billing/preferences",
47 get(get_billing_preferences).put(update_billing_preferences),
48 )
49 .route(
50 "/billing/subscriptions",
51 get(list_billing_subscriptions).post(create_billing_subscription),
52 )
53 .route(
54 "/billing/subscriptions/manage",
55 post(manage_billing_subscription),
56 )
57 .route(
58 "/billing/subscriptions/migrate",
59 post(migrate_to_new_billing),
60 )
61 .route(
62 "/billing/subscriptions/sync",
63 post(sync_billing_subscription),
64 )
65 .route("/billing/usage", get(get_current_usage))
66}
67
68#[derive(Debug, Deserialize)]
69struct GetBillingPreferencesParams {
70 github_user_id: i32,
71}
72
73#[derive(Debug, Serialize)]
74struct BillingPreferencesResponse {
75 trial_started_at: Option<String>,
76 max_monthly_llm_usage_spending_in_cents: i32,
77 model_request_overages_enabled: bool,
78 model_request_overages_spend_limit_in_cents: i32,
79}
80
81async fn get_billing_preferences(
82 Extension(app): Extension<Arc<AppState>>,
83 Query(params): Query<GetBillingPreferencesParams>,
84) -> Result<Json<BillingPreferencesResponse>> {
85 let user = app
86 .db
87 .get_user_by_github_user_id(params.github_user_id)
88 .await?
89 .context("user not found")?;
90
91 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
92 let preferences = app.db.get_billing_preferences(user.id).await?;
93
94 Ok(Json(BillingPreferencesResponse {
95 trial_started_at: billing_customer
96 .and_then(|billing_customer| billing_customer.trial_started_at)
97 .map(|trial_started_at| {
98 trial_started_at
99 .and_utc()
100 .to_rfc3339_opts(SecondsFormat::Millis, true)
101 }),
102 max_monthly_llm_usage_spending_in_cents: preferences
103 .as_ref()
104 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
105 preferences.max_monthly_llm_usage_spending_in_cents
106 }),
107 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
108 preferences.model_request_overages_enabled
109 }),
110 model_request_overages_spend_limit_in_cents: preferences
111 .as_ref()
112 .map_or(0, |preferences| {
113 preferences.model_request_overages_spend_limit_in_cents
114 }),
115 }))
116}
117
118#[derive(Debug, Deserialize)]
119struct UpdateBillingPreferencesBody {
120 github_user_id: i32,
121 #[serde(default)]
122 max_monthly_llm_usage_spending_in_cents: i32,
123 #[serde(default)]
124 model_request_overages_enabled: bool,
125 #[serde(default)]
126 model_request_overages_spend_limit_in_cents: i32,
127}
128
129async fn update_billing_preferences(
130 Extension(app): Extension<Arc<AppState>>,
131 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
132 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
133) -> Result<Json<BillingPreferencesResponse>> {
134 let user = app
135 .db
136 .get_user_by_github_user_id(body.github_user_id)
137 .await?
138 .context("user not found")?;
139
140 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
141
142 let max_monthly_llm_usage_spending_in_cents =
143 body.max_monthly_llm_usage_spending_in_cents.max(0);
144 let model_request_overages_spend_limit_in_cents =
145 body.model_request_overages_spend_limit_in_cents.max(0);
146
147 let billing_preferences =
148 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
149 app.db
150 .update_billing_preferences(
151 user.id,
152 &UpdateBillingPreferencesParams {
153 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
154 max_monthly_llm_usage_spending_in_cents,
155 ),
156 model_request_overages_enabled: ActiveValue::set(
157 body.model_request_overages_enabled,
158 ),
159 model_request_overages_spend_limit_in_cents: ActiveValue::set(
160 model_request_overages_spend_limit_in_cents,
161 ),
162 },
163 )
164 .await?
165 } else {
166 app.db
167 .create_billing_preferences(
168 user.id,
169 &crate::db::CreateBillingPreferencesParams {
170 max_monthly_llm_usage_spending_in_cents,
171 model_request_overages_enabled: body.model_request_overages_enabled,
172 model_request_overages_spend_limit_in_cents,
173 },
174 )
175 .await?
176 };
177
178 SnowflakeRow::new(
179 "Billing Preferences Updated",
180 Some(user.metrics_id),
181 user.admin,
182 None,
183 json!({
184 "user_id": user.id,
185 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
186 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
187 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
188 }),
189 )
190 .write(&app.kinesis_client, &app.config.kinesis_stream)
191 .await
192 .log_err();
193
194 rpc_server.refresh_llm_tokens_for_user(user.id).await;
195
196 Ok(Json(BillingPreferencesResponse {
197 trial_started_at: billing_customer
198 .and_then(|billing_customer| billing_customer.trial_started_at)
199 .map(|trial_started_at| {
200 trial_started_at
201 .and_utc()
202 .to_rfc3339_opts(SecondsFormat::Millis, true)
203 }),
204 max_monthly_llm_usage_spending_in_cents: billing_preferences
205 .max_monthly_llm_usage_spending_in_cents,
206 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
207 model_request_overages_spend_limit_in_cents: billing_preferences
208 .model_request_overages_spend_limit_in_cents,
209 }))
210}
211
212#[derive(Debug, Deserialize)]
213struct ListBillingSubscriptionsParams {
214 github_user_id: i32,
215}
216
217#[derive(Debug, Serialize)]
218struct BillingSubscriptionJson {
219 id: BillingSubscriptionId,
220 name: String,
221 status: StripeSubscriptionStatus,
222 trial_end_at: Option<String>,
223 cancel_at: Option<String>,
224 /// Whether this subscription can be canceled.
225 is_cancelable: bool,
226}
227
228#[derive(Debug, Serialize)]
229struct ListBillingSubscriptionsResponse {
230 subscriptions: Vec<BillingSubscriptionJson>,
231}
232
233async fn list_billing_subscriptions(
234 Extension(app): Extension<Arc<AppState>>,
235 Query(params): Query<ListBillingSubscriptionsParams>,
236) -> Result<Json<ListBillingSubscriptionsResponse>> {
237 let user = app
238 .db
239 .get_user_by_github_user_id(params.github_user_id)
240 .await?
241 .context("user not found")?;
242
243 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
244
245 Ok(Json(ListBillingSubscriptionsResponse {
246 subscriptions: subscriptions
247 .into_iter()
248 .map(|subscription| BillingSubscriptionJson {
249 id: subscription.id,
250 name: match subscription.kind {
251 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
252 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
253 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
254 None => "Zed LLM Usage".to_string(),
255 },
256 status: subscription.stripe_subscription_status,
257 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
258 maybe!({
259 let end_at = subscription.stripe_current_period_end?;
260 let end_at = DateTime::from_timestamp(end_at, 0)?;
261
262 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
263 })
264 } else {
265 None
266 },
267 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
268 cancel_at
269 .and_utc()
270 .to_rfc3339_opts(SecondsFormat::Millis, true)
271 }),
272 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
273 && subscription.stripe_subscription_status.is_cancelable()
274 && subscription.stripe_cancel_at.is_none(),
275 })
276 .collect(),
277 }))
278}
279
280#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
281#[serde(rename_all = "snake_case")]
282enum ProductCode {
283 ZedPro,
284 ZedProTrial,
285 ZedFree,
286}
287
288#[derive(Debug, Deserialize)]
289struct CreateBillingSubscriptionBody {
290 github_user_id: i32,
291 product: ProductCode,
292}
293
294#[derive(Debug, Serialize)]
295struct CreateBillingSubscriptionResponse {
296 checkout_session_url: String,
297}
298
299/// Initiates a Stripe Checkout session for creating a billing subscription.
300async fn create_billing_subscription(
301 Extension(app): Extension<Arc<AppState>>,
302 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
303) -> Result<Json<CreateBillingSubscriptionResponse>> {
304 let user = app
305 .db
306 .get_user_by_github_user_id(body.github_user_id)
307 .await?
308 .context("user not found")?;
309
310 let Some(stripe_billing) = app.stripe_billing.clone() else {
311 log::error!("failed to retrieve Stripe billing object");
312 Err(Error::http(
313 StatusCode::NOT_IMPLEMENTED,
314 "not supported".into(),
315 ))?
316 };
317
318 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
319 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
320 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
321
322 if !is_checkout_allowed {
323 return Err(Error::http(
324 StatusCode::CONFLICT,
325 "user already has an active subscription".into(),
326 ));
327 }
328 }
329
330 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
331 if let Some(existing_billing_customer) = &existing_billing_customer {
332 if existing_billing_customer.has_overdue_invoices {
333 return Err(Error::http(
334 StatusCode::PAYMENT_REQUIRED,
335 "user has overdue invoices".into(),
336 ));
337 }
338 }
339
340 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
341 CustomerId::from_str(&existing_customer.stripe_customer_id)
342 .context("failed to parse customer ID")?
343 } else {
344 stripe_billing
345 .find_or_create_customer_by_email(user.email_address.as_deref())
346 .await?
347 };
348
349 let success_url = format!(
350 "{}/account?checkout_complete=1",
351 app.config.zed_dot_dev_url()
352 );
353
354 let checkout_session_url = match body.product {
355 ProductCode::ZedPro => {
356 stripe_billing
357 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
358 .await?
359 }
360 ProductCode::ZedProTrial => {
361 if let Some(existing_billing_customer) = &existing_billing_customer {
362 if existing_billing_customer.trial_started_at.is_some() {
363 return Err(Error::http(
364 StatusCode::FORBIDDEN,
365 "user already used free trial".into(),
366 ));
367 }
368 }
369
370 let feature_flags = app.db.get_user_flags(user.id).await?;
371
372 stripe_billing
373 .checkout_with_zed_pro_trial(
374 customer_id,
375 &user.github_login,
376 feature_flags,
377 &success_url,
378 )
379 .await?
380 }
381 ProductCode::ZedFree => {
382 stripe_billing
383 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
384 .await?
385 }
386 };
387
388 Ok(Json(CreateBillingSubscriptionResponse {
389 checkout_session_url,
390 }))
391}
392
393#[derive(Debug, PartialEq, Deserialize)]
394#[serde(rename_all = "snake_case")]
395enum ManageSubscriptionIntent {
396 /// The user intends to manage their subscription.
397 ///
398 /// This will open the Stripe billing portal without putting the user in a specific flow.
399 ManageSubscription,
400 /// The user intends to update their payment method.
401 UpdatePaymentMethod,
402 /// The user intends to upgrade to Zed Pro.
403 UpgradeToPro,
404 /// The user intends to cancel their subscription.
405 Cancel,
406 /// The user intends to stop the cancellation of their subscription.
407 StopCancellation,
408}
409
410#[derive(Debug, Deserialize)]
411struct ManageBillingSubscriptionBody {
412 github_user_id: i32,
413 intent: ManageSubscriptionIntent,
414 /// The ID of the subscription to manage.
415 subscription_id: BillingSubscriptionId,
416 redirect_to: Option<String>,
417}
418
419#[derive(Debug, Serialize)]
420struct ManageBillingSubscriptionResponse {
421 billing_portal_session_url: Option<String>,
422}
423
424/// Initiates a Stripe customer portal session for managing a billing subscription.
425async fn manage_billing_subscription(
426 Extension(app): Extension<Arc<AppState>>,
427 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
428) -> Result<Json<ManageBillingSubscriptionResponse>> {
429 let user = app
430 .db
431 .get_user_by_github_user_id(body.github_user_id)
432 .await?
433 .context("user not found")?;
434
435 let Some(stripe_client) = app.stripe_client.clone() else {
436 log::error!("failed to retrieve Stripe client");
437 Err(Error::http(
438 StatusCode::NOT_IMPLEMENTED,
439 "not supported".into(),
440 ))?
441 };
442
443 let Some(stripe_billing) = app.stripe_billing.clone() else {
444 log::error!("failed to retrieve Stripe billing object");
445 Err(Error::http(
446 StatusCode::NOT_IMPLEMENTED,
447 "not supported".into(),
448 ))?
449 };
450
451 let customer = app
452 .db
453 .get_billing_customer_by_user_id(user.id)
454 .await?
455 .context("billing customer not found")?;
456 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
457 .context("failed to parse customer ID")?;
458
459 let subscription = app
460 .db
461 .get_billing_subscription_by_id(body.subscription_id)
462 .await?
463 .context("subscription not found")?;
464 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
465 .context("failed to parse subscription ID")?;
466
467 if body.intent == ManageSubscriptionIntent::StopCancellation {
468 let updated_stripe_subscription = Subscription::update(
469 &stripe_client,
470 &subscription_id,
471 stripe::UpdateSubscription {
472 cancel_at_period_end: Some(false),
473 ..Default::default()
474 },
475 )
476 .await?;
477
478 app.db
479 .update_billing_subscription(
480 subscription.id,
481 &UpdateBillingSubscriptionParams {
482 stripe_cancel_at: ActiveValue::set(
483 updated_stripe_subscription
484 .cancel_at
485 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
486 .map(|time| time.naive_utc()),
487 ),
488 ..Default::default()
489 },
490 )
491 .await?;
492
493 return Ok(Json(ManageBillingSubscriptionResponse {
494 billing_portal_session_url: None,
495 }));
496 }
497
498 let flow = match body.intent {
499 ManageSubscriptionIntent::ManageSubscription => None,
500 ManageSubscriptionIntent::UpgradeToPro => {
501 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
502 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
503
504 let stripe_subscription =
505 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
506
507 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
508 && stripe_subscription.items.data.iter().any(|item| {
509 item.price
510 .as_ref()
511 .map_or(false, |price| price.id == zed_pro_price_id)
512 });
513 if is_on_zed_pro_trial {
514 let payment_methods = PaymentMethod::list(
515 &stripe_client,
516 &stripe::ListPaymentMethods {
517 customer: Some(stripe_subscription.customer.id()),
518 ..Default::default()
519 },
520 )
521 .await?;
522
523 let has_payment_method = !payment_methods.data.is_empty();
524 if !has_payment_method {
525 return Err(Error::http(
526 StatusCode::BAD_REQUEST,
527 "missing payment method".into(),
528 ));
529 }
530
531 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
532 Subscription::update(
533 &stripe_client,
534 &stripe_subscription.id,
535 stripe::UpdateSubscription {
536 trial_end: Some(stripe::Scheduled::now()),
537 ..Default::default()
538 },
539 )
540 .await?;
541
542 return Ok(Json(ManageBillingSubscriptionResponse {
543 billing_portal_session_url: None,
544 }));
545 }
546
547 let subscription_item_to_update = stripe_subscription
548 .items
549 .data
550 .iter()
551 .find_map(|item| {
552 let price = item.price.as_ref()?;
553
554 if price.id == zed_free_price_id {
555 Some(item.id.clone())
556 } else {
557 None
558 }
559 })
560 .context("No subscription item to update")?;
561
562 Some(CreateBillingPortalSessionFlowData {
563 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
564 subscription_update_confirm: Some(
565 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
566 subscription: subscription.stripe_subscription_id,
567 items: vec![
568 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
569 id: subscription_item_to_update.to_string(),
570 price: Some(zed_pro_price_id.to_string()),
571 quantity: Some(1),
572 },
573 ],
574 discounts: None,
575 },
576 ),
577 ..Default::default()
578 })
579 }
580 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
581 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
582 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
583 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
584 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
585 return_url: format!(
586 "{}{path}",
587 app.config.zed_dot_dev_url(),
588 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
589 ),
590 }),
591 ..Default::default()
592 }),
593 ..Default::default()
594 }),
595 ManageSubscriptionIntent::Cancel => {
596 if subscription.kind == Some(SubscriptionKind::ZedFree) {
597 return Err(Error::http(
598 StatusCode::BAD_REQUEST,
599 "free subscription cannot be canceled".into(),
600 ));
601 }
602
603 Some(CreateBillingPortalSessionFlowData {
604 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
605 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
606 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
607 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
608 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
609 }),
610 ..Default::default()
611 }),
612 subscription_cancel: Some(
613 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
614 subscription: subscription.stripe_subscription_id,
615 retention: None,
616 },
617 ),
618 ..Default::default()
619 })
620 }
621 ManageSubscriptionIntent::StopCancellation => unreachable!(),
622 };
623
624 let mut params = CreateBillingPortalSession::new(customer_id);
625 params.flow_data = flow;
626 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
627 params.return_url = Some(&return_url);
628
629 let session = BillingPortalSession::create(&stripe_client, params).await?;
630
631 Ok(Json(ManageBillingSubscriptionResponse {
632 billing_portal_session_url: Some(session.url),
633 }))
634}
635
636#[derive(Debug, Deserialize)]
637struct MigrateToNewBillingBody {
638 github_user_id: i32,
639}
640
641#[derive(Debug, Serialize)]
642struct MigrateToNewBillingResponse {
643 /// The ID of the subscription that was canceled.
644 canceled_subscription_id: Option<String>,
645}
646
647async fn migrate_to_new_billing(
648 Extension(app): Extension<Arc<AppState>>,
649 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
650) -> Result<Json<MigrateToNewBillingResponse>> {
651 let Some(stripe_client) = app.stripe_client.clone() else {
652 log::error!("failed to retrieve Stripe client");
653 Err(Error::http(
654 StatusCode::NOT_IMPLEMENTED,
655 "not supported".into(),
656 ))?
657 };
658
659 let user = app
660 .db
661 .get_user_by_github_user_id(body.github_user_id)
662 .await?
663 .context("user not found")?;
664
665 let old_billing_subscriptions_by_user = app
666 .db
667 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
668 .await?;
669
670 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
671 old_billing_subscriptions_by_user.get(&user.id)
672 {
673 let stripe_subscription_id = billing_subscription
674 .stripe_subscription_id
675 .parse::<stripe::SubscriptionId>()
676 .context("failed to parse Stripe subscription ID from database")?;
677
678 Subscription::cancel(
679 &stripe_client,
680 &stripe_subscription_id,
681 stripe::CancelSubscription {
682 invoice_now: Some(true),
683 ..Default::default()
684 },
685 )
686 .await?;
687
688 Some(stripe_subscription_id)
689 } else {
690 None
691 };
692
693 let all_feature_flags = app.db.list_feature_flags().await?;
694 let user_feature_flags = app.db.get_user_flags(user.id).await?;
695
696 for feature_flag in ["new-billing", "assistant2"] {
697 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
698 if already_in_feature_flag {
699 continue;
700 }
701
702 let feature_flag = all_feature_flags
703 .iter()
704 .find(|flag| flag.flag == feature_flag)
705 .context("failed to find feature flag: {feature_flag:?}")?;
706
707 app.db.add_user_flag(user.id, feature_flag.id).await?;
708 }
709
710 Ok(Json(MigrateToNewBillingResponse {
711 canceled_subscription_id: canceled_subscription_id
712 .map(|subscription_id| subscription_id.to_string()),
713 }))
714}
715
716#[derive(Debug, Deserialize)]
717struct SyncBillingSubscriptionBody {
718 github_user_id: i32,
719}
720
721#[derive(Debug, Serialize)]
722struct SyncBillingSubscriptionResponse {
723 stripe_customer_id: String,
724}
725
726async fn sync_billing_subscription(
727 Extension(app): Extension<Arc<AppState>>,
728 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
729) -> Result<Json<SyncBillingSubscriptionResponse>> {
730 let Some(stripe_client) = app.stripe_client.clone() else {
731 log::error!("failed to retrieve Stripe client");
732 Err(Error::http(
733 StatusCode::NOT_IMPLEMENTED,
734 "not supported".into(),
735 ))?
736 };
737
738 let user = app
739 .db
740 .get_user_by_github_user_id(body.github_user_id)
741 .await?
742 .context("user not found")?;
743
744 let billing_customer = app
745 .db
746 .get_billing_customer_by_user_id(user.id)
747 .await?
748 .context("billing customer not found")?;
749 let stripe_customer_id = billing_customer
750 .stripe_customer_id
751 .parse::<stripe::CustomerId>()
752 .context("failed to parse Stripe customer ID from database")?;
753
754 let subscriptions = Subscription::list(
755 &stripe_client,
756 &stripe::ListSubscriptions {
757 customer: Some(stripe_customer_id),
758 // Sync all non-canceled subscriptions.
759 status: None,
760 ..Default::default()
761 },
762 )
763 .await?;
764
765 for subscription in subscriptions.data {
766 let subscription_id = subscription.id.clone();
767
768 sync_subscription(&app, &stripe_client, subscription)
769 .await
770 .with_context(|| {
771 format!(
772 "failed to sync subscription {subscription_id} for user {}",
773 user.id,
774 )
775 })?;
776 }
777
778 Ok(Json(SyncBillingSubscriptionResponse {
779 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
780 }))
781}
782
/// How long we wait between two consecutive polls of the Stripe events API.
///
/// The value is a trade-off between:
/// 1. Being short enough that changes in Stripe are picked up quickly.
/// 2. Being long enough that polling doesn't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
795
/// Page size requested when listing Stripe events.
///
/// This is set to 100, the maximum Stripe allows, so that fewer requests are
/// needed:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
802
/// How many pages made up entirely of already-processed events we tolerate
/// before we stop retrieving further pages.
///
/// This keeps us from over-fetching the Stripe events API for events we have
/// already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
809
810/// Polls the Stripe events API periodically to reconcile the records in our
811/// database with the data in Stripe.
812pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
813 let Some(stripe_client) = app.stripe_client.clone() else {
814 log::warn!("failed to retrieve Stripe client");
815 return;
816 };
817
818 let executor = app.executor.clone();
819 executor.spawn_detached({
820 let executor = executor.clone();
821 async move {
822 loop {
823 poll_stripe_events(&app, &rpc_server, &stripe_client)
824 .await
825 .log_err();
826
827 executor.sleep(POLL_EVENTS_INTERVAL).await;
828 }
829 }
830 });
831}
832
833async fn poll_stripe_events(
834 app: &Arc<AppState>,
835 rpc_server: &Arc<Server>,
836 stripe_client: &stripe::Client,
837) -> anyhow::Result<()> {
838 fn event_type_to_string(event_type: EventType) -> String {
839 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
840 // so we need to unquote it.
841 event_type.to_string().trim_matches('"').to_string()
842 }
843
844 let event_types = [
845 EventType::CustomerCreated,
846 EventType::CustomerUpdated,
847 EventType::CustomerSubscriptionCreated,
848 EventType::CustomerSubscriptionUpdated,
849 EventType::CustomerSubscriptionPaused,
850 EventType::CustomerSubscriptionResumed,
851 EventType::CustomerSubscriptionDeleted,
852 ]
853 .into_iter()
854 .map(event_type_to_string)
855 .collect::<Vec<_>>();
856
857 let mut pages_of_already_processed_events = 0;
858 let mut unprocessed_events = Vec::new();
859
860 log::info!(
861 "Stripe events: starting retrieval for {}",
862 event_types.join(", ")
863 );
864 let mut params = ListEvents::new();
865 params.types = Some(event_types.clone());
866 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
867
868 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
869 .await?
870 .paginate(params);
871
872 loop {
873 let processed_event_ids = {
874 let event_ids = event_pages
875 .page
876 .data
877 .iter()
878 .map(|event| event.id.as_str())
879 .collect::<Vec<_>>();
880 app.db
881 .get_processed_stripe_events_by_event_ids(&event_ids)
882 .await?
883 .into_iter()
884 .map(|event| event.stripe_event_id)
885 .collect::<Vec<_>>()
886 };
887
888 let mut processed_events_in_page = 0;
889 let events_in_page = event_pages.page.data.len();
890 for event in &event_pages.page.data {
891 if processed_event_ids.contains(&event.id.to_string()) {
892 processed_events_in_page += 1;
893 log::debug!("Stripe events: already processed '{}', skipping", event.id);
894 } else {
895 unprocessed_events.push(event.clone());
896 }
897 }
898
899 if processed_events_in_page == events_in_page {
900 pages_of_already_processed_events += 1;
901 }
902
903 if event_pages.page.has_more {
904 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
905 {
906 log::info!(
907 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
908 );
909 break;
910 } else {
911 log::info!("Stripe events: retrieving next page");
912 event_pages = event_pages.next(&stripe_client).await?;
913 }
914 } else {
915 break;
916 }
917 }
918
919 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
920
921 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
922 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
923
924 for event in unprocessed_events {
925 let event_id = event.id.clone();
926 let processed_event_params = CreateProcessedStripeEventParams {
927 stripe_event_id: event.id.to_string(),
928 stripe_event_type: event_type_to_string(event.type_),
929 stripe_event_created_timestamp: event.created,
930 };
931
932 // If the event has happened too far in the past, we don't want to
933 // process it and risk overwriting other more-recent updates.
934 //
935 // 1 day was chosen arbitrarily. This could be made longer or shorter.
936 let one_day = Duration::from_secs(24 * 60 * 60);
937 let a_day_ago = Utc::now() - one_day;
938 if a_day_ago.timestamp() > event.created {
939 log::info!(
940 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
941 event_id
942 );
943 app.db
944 .create_processed_stripe_event(&processed_event_params)
945 .await?;
946
947 continue;
948 }
949
950 let process_result = match event.type_ {
951 EventType::CustomerCreated | EventType::CustomerUpdated => {
952 handle_customer_event(app, stripe_client, event).await
953 }
954 EventType::CustomerSubscriptionCreated
955 | EventType::CustomerSubscriptionUpdated
956 | EventType::CustomerSubscriptionPaused
957 | EventType::CustomerSubscriptionResumed
958 | EventType::CustomerSubscriptionDeleted => {
959 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
960 }
961 _ => Ok(()),
962 };
963
964 if let Some(()) = process_result
965 .with_context(|| format!("failed to process event {event_id} successfully"))
966 .log_err()
967 {
968 app.db
969 .create_processed_stripe_event(&processed_event_params)
970 .await?;
971 }
972 }
973
974 Ok(())
975}
976
977async fn handle_customer_event(
978 app: &Arc<AppState>,
979 _stripe_client: &stripe::Client,
980 event: stripe::Event,
981) -> anyhow::Result<()> {
982 let EventObject::Customer(customer) = event.data.object else {
983 bail!("unexpected event payload for {}", event.id);
984 };
985
986 log::info!("handling Stripe {} event: {}", event.type_, event.id);
987
988 let Some(email) = customer.email else {
989 log::info!("Stripe customer has no email: skipping");
990 return Ok(());
991 };
992
993 let Some(user) = app.db.get_user_by_email(&email).await? else {
994 log::info!("no user found for email: skipping");
995 return Ok(());
996 };
997
998 if let Some(existing_customer) = app
999 .db
1000 .get_billing_customer_by_stripe_customer_id(&customer.id)
1001 .await?
1002 {
1003 app.db
1004 .update_billing_customer(
1005 existing_customer.id,
1006 &UpdateBillingCustomerParams {
1007 // For now we just leave the information as-is, as it is not
1008 // likely to change.
1009 ..Default::default()
1010 },
1011 )
1012 .await?;
1013 } else {
1014 app.db
1015 .create_billing_customer(&CreateBillingCustomerParams {
1016 user_id: user.id,
1017 stripe_customer_id: customer.id.to_string(),
1018 })
1019 .await?;
1020 }
1021
1022 Ok(())
1023}
1024
1025async fn sync_subscription(
1026 app: &Arc<AppState>,
1027 stripe_client: &stripe::Client,
1028 subscription: stripe::Subscription,
1029) -> anyhow::Result<billing_customer::Model> {
1030 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1031 stripe_billing
1032 .determine_subscription_kind(&subscription)
1033 .await
1034 } else {
1035 None
1036 };
1037
1038 let billing_customer =
1039 find_or_create_billing_customer(app, stripe_client, subscription.customer)
1040 .await?
1041 .context("billing customer not found")?;
1042
1043 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1044 if subscription.status == SubscriptionStatus::Trialing {
1045 let current_period_start =
1046 DateTime::from_timestamp(subscription.current_period_start, 0)
1047 .context("No trial subscription period start")?;
1048
1049 app.db
1050 .update_billing_customer(
1051 billing_customer.id,
1052 &UpdateBillingCustomerParams {
1053 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1054 ..Default::default()
1055 },
1056 )
1057 .await?;
1058 }
1059 }
1060
1061 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1062 && subscription
1063 .cancellation_details
1064 .as_ref()
1065 .and_then(|details| details.reason)
1066 .map_or(false, |reason| {
1067 reason == CancellationDetailsReason::PaymentFailed
1068 });
1069
1070 if was_canceled_due_to_payment_failure {
1071 app.db
1072 .update_billing_customer(
1073 billing_customer.id,
1074 &UpdateBillingCustomerParams {
1075 has_overdue_invoices: ActiveValue::set(true),
1076 ..Default::default()
1077 },
1078 )
1079 .await?;
1080 }
1081
1082 if let Some(existing_subscription) = app
1083 .db
1084 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1085 .await?
1086 {
1087 app.db
1088 .update_billing_subscription(
1089 existing_subscription.id,
1090 &UpdateBillingSubscriptionParams {
1091 billing_customer_id: ActiveValue::set(billing_customer.id),
1092 kind: ActiveValue::set(subscription_kind),
1093 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1094 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1095 stripe_cancel_at: ActiveValue::set(
1096 subscription
1097 .cancel_at
1098 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1099 .map(|time| time.naive_utc()),
1100 ),
1101 stripe_cancellation_reason: ActiveValue::set(
1102 subscription
1103 .cancellation_details
1104 .and_then(|details| details.reason)
1105 .map(|reason| reason.into()),
1106 ),
1107 stripe_current_period_start: ActiveValue::set(Some(
1108 subscription.current_period_start,
1109 )),
1110 stripe_current_period_end: ActiveValue::set(Some(
1111 subscription.current_period_end,
1112 )),
1113 },
1114 )
1115 .await?;
1116 } else {
1117 if let Some(existing_subscription) = app
1118 .db
1119 .get_active_billing_subscription(billing_customer.user_id)
1120 .await?
1121 {
1122 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1123 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1124 {
1125 let stripe_subscription_id = existing_subscription
1126 .stripe_subscription_id
1127 .parse::<stripe::SubscriptionId>()
1128 .context("failed to parse Stripe subscription ID from database")?;
1129
1130 Subscription::cancel(
1131 &stripe_client,
1132 &stripe_subscription_id,
1133 stripe::CancelSubscription {
1134 invoice_now: None,
1135 ..Default::default()
1136 },
1137 )
1138 .await?;
1139 } else {
1140 // If the user already has an active billing subscription, ignore the
1141 // event and return an `Ok` to signal that it was processed
1142 // successfully.
1143 //
1144 // There is the possibility that this could cause us to not create a
1145 // subscription in the following scenario:
1146 //
1147 // 1. User has an active subscription A
1148 // 2. User cancels subscription A
1149 // 3. User creates a new subscription B
1150 // 4. We process the new subscription B before the cancellation of subscription A
1151 // 5. User ends up with no subscriptions
1152 //
1153 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1154
1155 log::info!(
1156 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1157 user_id = billing_customer.user_id,
1158 subscription_id = subscription.id
1159 );
1160 return Ok(billing_customer);
1161 }
1162 }
1163
1164 app.db
1165 .create_billing_subscription(&CreateBillingSubscriptionParams {
1166 billing_customer_id: billing_customer.id,
1167 kind: subscription_kind,
1168 stripe_subscription_id: subscription.id.to_string(),
1169 stripe_subscription_status: subscription.status.into(),
1170 stripe_cancellation_reason: subscription
1171 .cancellation_details
1172 .and_then(|details| details.reason)
1173 .map(|reason| reason.into()),
1174 stripe_current_period_start: Some(subscription.current_period_start),
1175 stripe_current_period_end: Some(subscription.current_period_end),
1176 })
1177 .await?;
1178 }
1179
1180 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1181 if subscription.status == SubscriptionStatus::Canceled
1182 || subscription.status == SubscriptionStatus::Paused
1183 {
1184 let already_has_active_billing_subscription = app
1185 .db
1186 .has_active_billing_subscription(billing_customer.user_id)
1187 .await?;
1188 if !already_has_active_billing_subscription {
1189 let stripe_customer_id = billing_customer
1190 .stripe_customer_id
1191 .parse::<stripe::CustomerId>()
1192 .context("failed to parse Stripe customer ID from database")?;
1193
1194 stripe_billing
1195 .subscribe_to_zed_free(stripe_customer_id)
1196 .await?;
1197 }
1198 }
1199 }
1200
1201 Ok(billing_customer)
1202}
1203
1204async fn handle_customer_subscription_event(
1205 app: &Arc<AppState>,
1206 rpc_server: &Arc<Server>,
1207 stripe_client: &stripe::Client,
1208 event: stripe::Event,
1209) -> anyhow::Result<()> {
1210 let EventObject::Subscription(subscription) = event.data.object else {
1211 bail!("unexpected event payload for {}", event.id);
1212 };
1213
1214 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1215
1216 let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1217
1218 // When the user's subscription changes, push down any changes to their plan.
1219 rpc_server
1220 .update_plan_for_user(billing_customer.user_id)
1221 .await
1222 .trace_err();
1223
1224 // When the user's subscription changes, we want to refresh their LLM tokens
1225 // to either grant/revoke access.
1226 rpc_server
1227 .refresh_llm_tokens_for_user(billing_customer.user_id)
1228 .await;
1229
1230 Ok(())
1231}
1232
1233#[derive(Debug, Deserialize)]
1234struct GetCurrentUsageParams {
1235 github_user_id: i32,
1236}
1237
1238#[derive(Debug, Serialize)]
1239struct UsageCounts {
1240 pub used: i32,
1241 pub limit: Option<i32>,
1242 pub remaining: Option<i32>,
1243}
1244
1245#[derive(Debug, Serialize)]
1246struct ModelRequestUsage {
1247 pub model: String,
1248 pub mode: CompletionMode,
1249 pub requests: i32,
1250}
1251
1252#[derive(Debug, Serialize)]
1253struct CurrentUsage {
1254 pub model_requests: UsageCounts,
1255 pub model_request_usage: Vec<ModelRequestUsage>,
1256 pub edit_predictions: UsageCounts,
1257}
1258
1259#[derive(Debug, Default, Serialize)]
1260struct GetCurrentUsageResponse {
1261 pub plan: String,
1262 pub current_usage: Option<CurrentUsage>,
1263}
1264
1265async fn get_current_usage(
1266 Extension(app): Extension<Arc<AppState>>,
1267 Query(params): Query<GetCurrentUsageParams>,
1268) -> Result<Json<GetCurrentUsageResponse>> {
1269 let user = app
1270 .db
1271 .get_user_by_github_user_id(params.github_user_id)
1272 .await?
1273 .context("user not found")?;
1274
1275 let feature_flags = app.db.get_user_flags(user.id).await?;
1276 let has_extended_trial = feature_flags
1277 .iter()
1278 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1279
1280 let Some(llm_db) = app.llm_db.clone() else {
1281 return Err(Error::http(
1282 StatusCode::NOT_IMPLEMENTED,
1283 "LLM database not available".into(),
1284 ));
1285 };
1286
1287 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1288 return Ok(Json(GetCurrentUsageResponse::default()));
1289 };
1290
1291 let subscription_period = maybe!({
1292 let period_start_at = subscription.current_period_start_at()?;
1293 let period_end_at = subscription.current_period_end_at()?;
1294
1295 Some((period_start_at, period_end_at))
1296 });
1297
1298 let Some((period_start_at, period_end_at)) = subscription_period else {
1299 return Ok(Json(GetCurrentUsageResponse::default()));
1300 };
1301
1302 let usage = llm_db
1303 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1304 .await?;
1305
1306 let plan = subscription
1307 .kind
1308 .map(Into::into)
1309 .unwrap_or(zed_llm_client::Plan::ZedFree);
1310
1311 let model_requests_limit = match plan.model_requests_limit() {
1312 zed_llm_client::UsageLimit::Limited(limit) => {
1313 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1314 1_000
1315 } else {
1316 limit
1317 };
1318
1319 Some(limit)
1320 }
1321 zed_llm_client::UsageLimit::Unlimited => None,
1322 };
1323
1324 let edit_predictions_limit = match plan.edit_predictions_limit() {
1325 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1326 zed_llm_client::UsageLimit::Unlimited => None,
1327 };
1328
1329 let Some(usage) = usage else {
1330 return Ok(Json(GetCurrentUsageResponse {
1331 plan: plan.as_str().to_string(),
1332 current_usage: Some(CurrentUsage {
1333 model_requests: UsageCounts {
1334 used: 0,
1335 limit: model_requests_limit,
1336 remaining: model_requests_limit,
1337 },
1338 model_request_usage: Vec::new(),
1339 edit_predictions: UsageCounts {
1340 used: 0,
1341 limit: edit_predictions_limit,
1342 remaining: edit_predictions_limit,
1343 },
1344 }),
1345 }));
1346 };
1347
1348 let subscription_usage_meters = llm_db
1349 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1350 .await?;
1351
1352 let model_request_usage = subscription_usage_meters
1353 .into_iter()
1354 .filter_map(|(usage_meter, _usage)| {
1355 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1356
1357 Some(ModelRequestUsage {
1358 model: model.name.clone(),
1359 mode: usage_meter.mode,
1360 requests: usage_meter.requests,
1361 })
1362 })
1363 .collect::<Vec<_>>();
1364
1365 Ok(Json(GetCurrentUsageResponse {
1366 plan: plan.as_str().to_string(),
1367 current_usage: Some(CurrentUsage {
1368 model_requests: UsageCounts {
1369 used: usage.model_requests,
1370 limit: model_requests_limit,
1371 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1372 },
1373 model_request_usage,
1374 edit_predictions: UsageCounts {
1375 used: usage.edit_predictions,
1376 limit: edit_predictions_limit,
1377 remaining: edit_predictions_limit
1378 .map(|limit| (limit - usage.edit_predictions).max(0)),
1379 },
1380 }),
1381 }))
1382}
1383
1384impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1385 fn from(value: SubscriptionStatus) -> Self {
1386 match value {
1387 SubscriptionStatus::Incomplete => Self::Incomplete,
1388 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1389 SubscriptionStatus::Trialing => Self::Trialing,
1390 SubscriptionStatus::Active => Self::Active,
1391 SubscriptionStatus::PastDue => Self::PastDue,
1392 SubscriptionStatus::Canceled => Self::Canceled,
1393 SubscriptionStatus::Unpaid => Self::Unpaid,
1394 SubscriptionStatus::Paused => Self::Paused,
1395 }
1396 }
1397}
1398
1399impl From<CancellationDetailsReason> for StripeCancellationReason {
1400 fn from(value: CancellationDetailsReason) -> Self {
1401 match value {
1402 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1403 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1404 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1405 }
1406 }
1407}
1408
1409/// Finds or creates a billing customer using the provided customer.
1410pub async fn find_or_create_billing_customer(
1411 app: &Arc<AppState>,
1412 stripe_client: &stripe::Client,
1413 customer_or_id: Expandable<Customer>,
1414) -> anyhow::Result<Option<billing_customer::Model>> {
1415 let customer_id = match &customer_or_id {
1416 Expandable::Id(id) => id,
1417 Expandable::Object(customer) => customer.id.as_ref(),
1418 };
1419
1420 // If we already have a billing customer record associated with the Stripe customer,
1421 // there's nothing more we need to do.
1422 if let Some(billing_customer) = app
1423 .db
1424 .get_billing_customer_by_stripe_customer_id(customer_id)
1425 .await?
1426 {
1427 return Ok(Some(billing_customer));
1428 }
1429
1430 // If all we have is a customer ID, resolve it to a full customer record by
1431 // hitting the Stripe API.
1432 let customer = match customer_or_id {
1433 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1434 Expandable::Object(customer) => *customer,
1435 };
1436
1437 let Some(email) = customer.email else {
1438 return Ok(None);
1439 };
1440
1441 let Some(user) = app.db.get_user_by_email(&email).await? else {
1442 return Ok(None);
1443 };
1444
1445 let billing_customer = app
1446 .db
1447 .create_billing_customer(&CreateBillingCustomerParams {
1448 user_id: user.id,
1449 stripe_customer_id: customer.id.to_string(),
1450 })
1451 .await?;
1452
1453 Ok(Some(billing_customer))
1454}
1455
/// Pause between consecutive runs of the LLM request usage sync with Stripe.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1457
1458pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1459 let Some(stripe_billing) = app.stripe_billing.clone() else {
1460 log::warn!("failed to retrieve Stripe billing object");
1461 return;
1462 };
1463 let Some(llm_db) = app.llm_db.clone() else {
1464 log::warn!("failed to retrieve LLM database");
1465 return;
1466 };
1467
1468 let executor = app.executor.clone();
1469 executor.spawn_detached({
1470 let executor = executor.clone();
1471 async move {
1472 loop {
1473 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1474 .await
1475 .context("failed to sync LLM request usage to Stripe")
1476 .trace_err();
1477 executor
1478 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1479 .await;
1480 }
1481 }
1482 });
1483}
1484
1485async fn sync_model_request_usage_with_stripe(
1486 app: &Arc<AppState>,
1487 llm_db: &Arc<LlmDatabase>,
1488 stripe_billing: &Arc<StripeBilling>,
1489) -> anyhow::Result<()> {
1490 let staff_users = app.db.get_staff_users().await?;
1491 let staff_user_ids = staff_users
1492 .iter()
1493 .map(|user| user.id)
1494 .collect::<HashSet<UserId>>();
1495
1496 let usage_meters = llm_db
1497 .get_current_subscription_usage_meters(Utc::now())
1498 .await?;
1499 let usage_meters = usage_meters
1500 .into_iter()
1501 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1502 .collect::<Vec<_>>();
1503 let user_ids = usage_meters
1504 .iter()
1505 .map(|(_, usage)| usage.user_id)
1506 .collect::<HashSet<UserId>>();
1507 let billing_subscriptions = app
1508 .db
1509 .get_active_zed_pro_billing_subscriptions(user_ids)
1510 .await?;
1511
1512 let claude_sonnet_4 = stripe_billing
1513 .find_price_by_lookup_key("claude-sonnet-4-requests")
1514 .await?;
1515 let claude_sonnet_4_max = stripe_billing
1516 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1517 .await?;
1518 let claude_3_5_sonnet = stripe_billing
1519 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1520 .await?;
1521 let claude_3_7_sonnet = stripe_billing
1522 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1523 .await?;
1524 let claude_3_7_sonnet_max = stripe_billing
1525 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1526 .await?;
1527
1528 for (usage_meter, usage) in usage_meters {
1529 maybe!(async {
1530 let Some((billing_customer, billing_subscription)) =
1531 billing_subscriptions.get(&usage.user_id)
1532 else {
1533 bail!(
1534 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1535 usage.user_id
1536 );
1537 };
1538
1539 let stripe_customer_id = billing_customer
1540 .stripe_customer_id
1541 .parse::<stripe::CustomerId>()
1542 .context("failed to parse Stripe customer ID from database")?;
1543 let stripe_subscription_id = billing_subscription
1544 .stripe_subscription_id
1545 .parse::<stripe::SubscriptionId>()
1546 .context("failed to parse Stripe subscription ID from database")?;
1547
1548 let model = llm_db.model_by_id(usage_meter.model_id)?;
1549
1550 let (price, meter_event_name) = match model.name.as_str() {
1551 "claude-sonnet-4" => match usage_meter.mode {
1552 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1553 CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
1554 },
1555 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1556 "claude-3-7-sonnet" => match usage_meter.mode {
1557 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1558 CompletionMode::Max => {
1559 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1560 }
1561 },
1562 model_name => {
1563 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1564 }
1565 };
1566
1567 stripe_billing
1568 .subscribe_to_price(&stripe_subscription_id, price)
1569 .await?;
1570 stripe_billing
1571 .bill_model_request_usage(
1572 &stripe_customer_id,
1573 meter_event_name,
1574 usage_meter.requests,
1575 )
1576 .await?;
1577
1578 Ok(())
1579 })
1580 .await
1581 .log_err();
1582 }
1583
1584 Ok(())
1585}