1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
21 Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
31use crate::rpc::{ResultExt as _, Server};
32use crate::{AppState, Error, Result};
33use crate::{db::UserId, llm::db::LlmDatabase};
34use crate::{
35 db::{
36 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
37 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
38 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
39 },
40 stripe_billing::StripeBilling,
41};
42
43pub fn router() -> Router {
44 Router::new()
45 .route(
46 "/billing/preferences",
47 get(get_billing_preferences).put(update_billing_preferences),
48 )
49 .route(
50 "/billing/subscriptions",
51 get(list_billing_subscriptions).post(create_billing_subscription),
52 )
53 .route(
54 "/billing/subscriptions/manage",
55 post(manage_billing_subscription),
56 )
57 .route(
58 "/billing/subscriptions/migrate",
59 post(migrate_to_new_billing),
60 )
61 .route(
62 "/billing/subscriptions/sync",
63 post(sync_billing_subscription),
64 )
65 .route("/billing/usage", get(get_current_usage))
66}
67
/// Query parameters accepted by the `GET /billing/preferences` handler.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID used to look up the user whose preferences are returned.
    github_user_id: i32,
}
72
/// JSON response returned by the handlers that read or update billing preferences.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, as an RFC 3339 timestamp with millisecond
    /// precision. `None` when there is no billing customer or no recorded trial start.
    trial_started_at: Option<String>,
    /// Monthly LLM usage spending cap, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled.
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
80
81async fn get_billing_preferences(
82 Extension(app): Extension<Arc<AppState>>,
83 Query(params): Query<GetBillingPreferencesParams>,
84) -> Result<Json<BillingPreferencesResponse>> {
85 let user = app
86 .db
87 .get_user_by_github_user_id(params.github_user_id)
88 .await?
89 .context("user not found")?;
90
91 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
92 let preferences = app.db.get_billing_preferences(user.id).await?;
93
94 Ok(Json(BillingPreferencesResponse {
95 trial_started_at: billing_customer
96 .and_then(|billing_customer| billing_customer.trial_started_at)
97 .map(|trial_started_at| {
98 trial_started_at
99 .and_utc()
100 .to_rfc3339_opts(SecondsFormat::Millis, true)
101 }),
102 max_monthly_llm_usage_spending_in_cents: preferences
103 .as_ref()
104 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
105 preferences.max_monthly_llm_usage_spending_in_cents
106 }),
107 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
108 preferences.model_request_overages_enabled
109 }),
110 model_request_overages_spend_limit_in_cents: preferences
111 .as_ref()
112 .map_or(0, |preferences| {
113 preferences.model_request_overages_spend_limit_in_cents
114 }),
115 }))
116}
117
/// JSON body accepted by the `PUT /billing/preferences` handler.
///
/// All preference fields fall back to their type's default (`0` / `false`)
/// when omitted from the request body.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID used to look up the user whose preferences are updated.
    github_user_id: i32,
    /// Requested monthly LLM usage spending cap, in cents.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Requested spend limit for model request overages, in cents.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
128
129async fn update_billing_preferences(
130 Extension(app): Extension<Arc<AppState>>,
131 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
132 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
133) -> Result<Json<BillingPreferencesResponse>> {
134 let user = app
135 .db
136 .get_user_by_github_user_id(body.github_user_id)
137 .await?
138 .context("user not found")?;
139
140 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
141
142 let max_monthly_llm_usage_spending_in_cents =
143 body.max_monthly_llm_usage_spending_in_cents.max(0);
144 let model_request_overages_spend_limit_in_cents =
145 body.model_request_overages_spend_limit_in_cents.max(0);
146
147 let billing_preferences =
148 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
149 app.db
150 .update_billing_preferences(
151 user.id,
152 &UpdateBillingPreferencesParams {
153 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
154 max_monthly_llm_usage_spending_in_cents,
155 ),
156 model_request_overages_enabled: ActiveValue::set(
157 body.model_request_overages_enabled,
158 ),
159 model_request_overages_spend_limit_in_cents: ActiveValue::set(
160 model_request_overages_spend_limit_in_cents,
161 ),
162 },
163 )
164 .await?
165 } else {
166 app.db
167 .create_billing_preferences(
168 user.id,
169 &crate::db::CreateBillingPreferencesParams {
170 max_monthly_llm_usage_spending_in_cents,
171 model_request_overages_enabled: body.model_request_overages_enabled,
172 model_request_overages_spend_limit_in_cents,
173 },
174 )
175 .await?
176 };
177
178 SnowflakeRow::new(
179 "Billing Preferences Updated",
180 Some(user.metrics_id),
181 user.admin,
182 None,
183 json!({
184 "user_id": user.id,
185 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
186 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
187 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
188 }),
189 )
190 .write(&app.kinesis_client, &app.config.kinesis_stream)
191 .await
192 .log_err();
193
194 rpc_server.refresh_llm_tokens_for_user(user.id).await;
195
196 Ok(Json(BillingPreferencesResponse {
197 trial_started_at: billing_customer
198 .and_then(|billing_customer| billing_customer.trial_started_at)
199 .map(|trial_started_at| {
200 trial_started_at
201 .and_utc()
202 .to_rfc3339_opts(SecondsFormat::Millis, true)
203 }),
204 max_monthly_llm_usage_spending_in_cents: billing_preferences
205 .max_monthly_llm_usage_spending_in_cents,
206 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
207 model_request_overages_spend_limit_in_cents: billing_preferences
208 .model_request_overages_spend_limit_in_cents,
209 }))
210}
211
/// Query parameters accepted by the `GET /billing/subscriptions` handler.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}
216
/// JSON representation of a single billing subscription in list responses.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// ID of the subscription record.
    id: BillingSubscriptionId,
    /// Human-readable plan name derived from the subscription kind.
    name: String,
    /// Stripe status stored for the subscription.
    status: StripeSubscriptionStatus,
    /// End of the trial as an RFC 3339 timestamp; only populated for trial
    /// subscriptions that have a current period end.
    trial_end_at: Option<String>,
    /// Scheduled cancellation time as an RFC 3339 timestamp, if any.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}
227
/// JSON response returned by the `GET /billing/subscriptions` handler.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The user's billing subscriptions.
    subscriptions: Vec<BillingSubscriptionJson>,
}
232
233async fn list_billing_subscriptions(
234 Extension(app): Extension<Arc<AppState>>,
235 Query(params): Query<ListBillingSubscriptionsParams>,
236) -> Result<Json<ListBillingSubscriptionsResponse>> {
237 let user = app
238 .db
239 .get_user_by_github_user_id(params.github_user_id)
240 .await?
241 .context("user not found")?;
242
243 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
244
245 Ok(Json(ListBillingSubscriptionsResponse {
246 subscriptions: subscriptions
247 .into_iter()
248 .map(|subscription| BillingSubscriptionJson {
249 id: subscription.id,
250 name: match subscription.kind {
251 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
252 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
253 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
254 None => "Zed LLM Usage".to_string(),
255 },
256 status: subscription.stripe_subscription_status,
257 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
258 maybe!({
259 let end_at = subscription.stripe_current_period_end?;
260 let end_at = DateTime::from_timestamp(end_at, 0)?;
261
262 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
263 })
264 } else {
265 None
266 },
267 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
268 cancel_at
269 .and_utc()
270 .to_rfc3339_opts(SecondsFormat::Millis, true)
271 }),
272 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
273 && subscription.stripe_subscription_status.is_cancelable()
274 && subscription.stripe_cancel_at.is_none(),
275 })
276 .collect(),
277 }))
278}
279
/// The product a checkout session is being created for.
///
/// Deserialized from snake_case strings (e.g. `zed_pro_trial`).
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// The Zed Pro plan.
    ZedPro,
    /// The Zed Pro trial.
    ZedProTrial,
    /// The Zed Free plan.
    ZedFree,
}
287
/// JSON body accepted by the `POST /billing/subscriptions` handler.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the user starting the checkout.
    github_user_id: i32,
    /// The product to check out.
    product: ProductCode,
}
293
/// JSON response returned by the `POST /billing/subscriptions` handler.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the checkout session returned by the Stripe billing helper.
    checkout_session_url: String,
}
298
299/// Initiates a Stripe Checkout session for creating a billing subscription.
300async fn create_billing_subscription(
301 Extension(app): Extension<Arc<AppState>>,
302 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
303) -> Result<Json<CreateBillingSubscriptionResponse>> {
304 let user = app
305 .db
306 .get_user_by_github_user_id(body.github_user_id)
307 .await?
308 .context("user not found")?;
309
310 let Some(stripe_billing) = app.stripe_billing.clone() else {
311 log::error!("failed to retrieve Stripe billing object");
312 Err(Error::http(
313 StatusCode::NOT_IMPLEMENTED,
314 "not supported".into(),
315 ))?
316 };
317
318 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
319 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
320 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
321
322 if !is_checkout_allowed {
323 return Err(Error::http(
324 StatusCode::CONFLICT,
325 "user already has an active subscription".into(),
326 ));
327 }
328 }
329
330 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
331 if let Some(existing_billing_customer) = &existing_billing_customer {
332 if existing_billing_customer.has_overdue_invoices {
333 return Err(Error::http(
334 StatusCode::PAYMENT_REQUIRED,
335 "user has overdue invoices".into(),
336 ));
337 }
338 }
339
340 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
341 CustomerId::from_str(&existing_customer.stripe_customer_id)
342 .context("failed to parse customer ID")?
343 } else {
344 stripe_billing
345 .find_or_create_customer_by_email(user.email_address.as_deref())
346 .await?
347 .try_into()?
348 };
349
350 let success_url = format!(
351 "{}/account?checkout_complete=1",
352 app.config.zed_dot_dev_url()
353 );
354
355 let checkout_session_url = match body.product {
356 ProductCode::ZedPro => {
357 stripe_billing
358 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
359 .await?
360 }
361 ProductCode::ZedProTrial => {
362 if let Some(existing_billing_customer) = &existing_billing_customer {
363 if existing_billing_customer.trial_started_at.is_some() {
364 return Err(Error::http(
365 StatusCode::FORBIDDEN,
366 "user already used free trial".into(),
367 ));
368 }
369 }
370
371 let feature_flags = app.db.get_user_flags(user.id).await?;
372
373 stripe_billing
374 .checkout_with_zed_pro_trial(
375 customer_id,
376 &user.github_login,
377 feature_flags,
378 &success_url,
379 )
380 .await?
381 }
382 ProductCode::ZedFree => {
383 stripe_billing
384 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
385 .await?
386 }
387 };
388
389 Ok(Json(CreateBillingSubscriptionResponse {
390 checkout_session_url,
391 }))
392}
393
/// What the user wants to do with an existing subscription.
///
/// Deserialized from snake_case strings (e.g. `stop_cancellation`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}
410
/// JSON body accepted by the `POST /billing/subscriptions/manage` handler.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user managing the subscription.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Path appended to the site base URL to build the redirect used after the
    /// update-payment-method flow completes; `/account` is used when absent.
    redirect_to: Option<String>,
}
419
/// JSON response returned by the `POST /billing/subscriptions/manage` handler.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the created billing portal session. `None` when the handler
    /// applied the change directly and no portal session was created.
    billing_portal_session_url: Option<String>,
}
424
425/// Initiates a Stripe customer portal session for managing a billing subscription.
426async fn manage_billing_subscription(
427 Extension(app): Extension<Arc<AppState>>,
428 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
429) -> Result<Json<ManageBillingSubscriptionResponse>> {
430 let user = app
431 .db
432 .get_user_by_github_user_id(body.github_user_id)
433 .await?
434 .context("user not found")?;
435
436 let Some(stripe_client) = app.stripe_client.clone() else {
437 log::error!("failed to retrieve Stripe client");
438 Err(Error::http(
439 StatusCode::NOT_IMPLEMENTED,
440 "not supported".into(),
441 ))?
442 };
443
444 let Some(stripe_billing) = app.stripe_billing.clone() else {
445 log::error!("failed to retrieve Stripe billing object");
446 Err(Error::http(
447 StatusCode::NOT_IMPLEMENTED,
448 "not supported".into(),
449 ))?
450 };
451
452 let customer = app
453 .db
454 .get_billing_customer_by_user_id(user.id)
455 .await?
456 .context("billing customer not found")?;
457 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
458 .context("failed to parse customer ID")?;
459
460 let subscription = app
461 .db
462 .get_billing_subscription_by_id(body.subscription_id)
463 .await?
464 .context("subscription not found")?;
465 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
466 .context("failed to parse subscription ID")?;
467
468 if body.intent == ManageSubscriptionIntent::StopCancellation {
469 let updated_stripe_subscription = Subscription::update(
470 &stripe_client,
471 &subscription_id,
472 stripe::UpdateSubscription {
473 cancel_at_period_end: Some(false),
474 ..Default::default()
475 },
476 )
477 .await?;
478
479 app.db
480 .update_billing_subscription(
481 subscription.id,
482 &UpdateBillingSubscriptionParams {
483 stripe_cancel_at: ActiveValue::set(
484 updated_stripe_subscription
485 .cancel_at
486 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
487 .map(|time| time.naive_utc()),
488 ),
489 ..Default::default()
490 },
491 )
492 .await?;
493
494 return Ok(Json(ManageBillingSubscriptionResponse {
495 billing_portal_session_url: None,
496 }));
497 }
498
499 let flow = match body.intent {
500 ManageSubscriptionIntent::ManageSubscription => None,
501 ManageSubscriptionIntent::UpgradeToPro => {
502 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
503 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
504
505 let stripe_subscription =
506 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
507
508 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
509 && stripe_subscription.items.data.iter().any(|item| {
510 item.price
511 .as_ref()
512 .map_or(false, |price| price.id == zed_pro_price_id)
513 });
514 if is_on_zed_pro_trial {
515 let payment_methods = PaymentMethod::list(
516 &stripe_client,
517 &stripe::ListPaymentMethods {
518 customer: Some(stripe_subscription.customer.id()),
519 ..Default::default()
520 },
521 )
522 .await?;
523
524 let has_payment_method = !payment_methods.data.is_empty();
525 if !has_payment_method {
526 return Err(Error::http(
527 StatusCode::BAD_REQUEST,
528 "missing payment method".into(),
529 ));
530 }
531
532 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
533 Subscription::update(
534 &stripe_client,
535 &stripe_subscription.id,
536 stripe::UpdateSubscription {
537 trial_end: Some(stripe::Scheduled::now()),
538 ..Default::default()
539 },
540 )
541 .await?;
542
543 return Ok(Json(ManageBillingSubscriptionResponse {
544 billing_portal_session_url: None,
545 }));
546 }
547
548 let subscription_item_to_update = stripe_subscription
549 .items
550 .data
551 .iter()
552 .find_map(|item| {
553 let price = item.price.as_ref()?;
554
555 if price.id == zed_free_price_id {
556 Some(item.id.clone())
557 } else {
558 None
559 }
560 })
561 .context("No subscription item to update")?;
562
563 Some(CreateBillingPortalSessionFlowData {
564 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
565 subscription_update_confirm: Some(
566 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
567 subscription: subscription.stripe_subscription_id,
568 items: vec![
569 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
570 id: subscription_item_to_update.to_string(),
571 price: Some(zed_pro_price_id.to_string()),
572 quantity: Some(1),
573 },
574 ],
575 discounts: None,
576 },
577 ),
578 ..Default::default()
579 })
580 }
581 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
582 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
583 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
584 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
585 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
586 return_url: format!(
587 "{}{path}",
588 app.config.zed_dot_dev_url(),
589 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
590 ),
591 }),
592 ..Default::default()
593 }),
594 ..Default::default()
595 }),
596 ManageSubscriptionIntent::Cancel => {
597 if subscription.kind == Some(SubscriptionKind::ZedFree) {
598 return Err(Error::http(
599 StatusCode::BAD_REQUEST,
600 "free subscription cannot be canceled".into(),
601 ));
602 }
603
604 Some(CreateBillingPortalSessionFlowData {
605 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
606 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
607 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
608 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
609 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
610 }),
611 ..Default::default()
612 }),
613 subscription_cancel: Some(
614 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
615 subscription: subscription.stripe_subscription_id,
616 retention: None,
617 },
618 ),
619 ..Default::default()
620 })
621 }
622 ManageSubscriptionIntent::StopCancellation => unreachable!(),
623 };
624
625 let mut params = CreateBillingPortalSession::new(customer_id);
626 params.flow_data = flow;
627 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
628 params.return_url = Some(&return_url);
629
630 let session = BillingPortalSession::create(&stripe_client, params).await?;
631
632 Ok(Json(ManageBillingSubscriptionResponse {
633 billing_portal_session_url: Some(session.url),
634 }))
635}
636
/// JSON body accepted by the `POST /billing/subscriptions/migrate` handler.
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
    /// GitHub user ID used to look up the user being migrated.
    github_user_id: i32,
}
641
/// JSON response returned by the `POST /billing/subscriptions/migrate` handler.
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
    /// The ID of the subscription that was canceled.
    ///
    /// `None` when the user had no active subscription to cancel.
    canceled_subscription_id: Option<String>,
}
647
648async fn migrate_to_new_billing(
649 Extension(app): Extension<Arc<AppState>>,
650 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
651) -> Result<Json<MigrateToNewBillingResponse>> {
652 let Some(stripe_client) = app.stripe_client.clone() else {
653 log::error!("failed to retrieve Stripe client");
654 Err(Error::http(
655 StatusCode::NOT_IMPLEMENTED,
656 "not supported".into(),
657 ))?
658 };
659
660 let user = app
661 .db
662 .get_user_by_github_user_id(body.github_user_id)
663 .await?
664 .context("user not found")?;
665
666 let old_billing_subscriptions_by_user = app
667 .db
668 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
669 .await?;
670
671 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
672 old_billing_subscriptions_by_user.get(&user.id)
673 {
674 let stripe_subscription_id = billing_subscription
675 .stripe_subscription_id
676 .parse::<stripe::SubscriptionId>()
677 .context("failed to parse Stripe subscription ID from database")?;
678
679 Subscription::cancel(
680 &stripe_client,
681 &stripe_subscription_id,
682 stripe::CancelSubscription {
683 invoice_now: Some(true),
684 ..Default::default()
685 },
686 )
687 .await?;
688
689 Some(stripe_subscription_id)
690 } else {
691 None
692 };
693
694 let all_feature_flags = app.db.list_feature_flags().await?;
695 let user_feature_flags = app.db.get_user_flags(user.id).await?;
696
697 for feature_flag in ["new-billing", "assistant2"] {
698 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
699 if already_in_feature_flag {
700 continue;
701 }
702
703 let feature_flag = all_feature_flags
704 .iter()
705 .find(|flag| flag.flag == feature_flag)
706 .context("failed to find feature flag: {feature_flag:?}")?;
707
708 app.db.add_user_flag(user.id, feature_flag.id).await?;
709 }
710
711 Ok(Json(MigrateToNewBillingResponse {
712 canceled_subscription_id: canceled_subscription_id
713 .map(|subscription_id| subscription_id.to_string()),
714 }))
715}
716
/// JSON body accepted by the `POST /billing/subscriptions/sync` handler.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID used to look up the user whose subscriptions are synced.
    github_user_id: i32,
}
721
/// JSON response returned by the `POST /billing/subscriptions/sync` handler.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID stored for the user's billing customer.
    stripe_customer_id: String,
}
726
727async fn sync_billing_subscription(
728 Extension(app): Extension<Arc<AppState>>,
729 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
730) -> Result<Json<SyncBillingSubscriptionResponse>> {
731 let Some(stripe_client) = app.stripe_client.clone() else {
732 log::error!("failed to retrieve Stripe client");
733 Err(Error::http(
734 StatusCode::NOT_IMPLEMENTED,
735 "not supported".into(),
736 ))?
737 };
738
739 let user = app
740 .db
741 .get_user_by_github_user_id(body.github_user_id)
742 .await?
743 .context("user not found")?;
744
745 let billing_customer = app
746 .db
747 .get_billing_customer_by_user_id(user.id)
748 .await?
749 .context("billing customer not found")?;
750 let stripe_customer_id = billing_customer
751 .stripe_customer_id
752 .parse::<stripe::CustomerId>()
753 .context("failed to parse Stripe customer ID from database")?;
754
755 let subscriptions = Subscription::list(
756 &stripe_client,
757 &stripe::ListSubscriptions {
758 customer: Some(stripe_customer_id),
759 // Sync all non-canceled subscriptions.
760 status: None,
761 ..Default::default()
762 },
763 )
764 .await?;
765
766 for subscription in subscriptions.data {
767 let subscription_id = subscription.id.clone();
768
769 sync_subscription(&app, &stripe_client, subscription)
770 .await
771 .with_context(|| {
772 format!(
773 "failed to sync subscription {subscription_id} for user {}",
774 user.id,
775 )
776 })?;
777 }
778
779 Ok(Json(SyncBillingSubscriptionResponse {
780 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
781 }))
782}
783
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
///
/// We currently poll every 5 seconds.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
796
/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// The quote below is presumably from the Stripe API documentation for
/// listing events (TODO confirm the source):
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
803
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
///
/// The threshold is only consulted when Stripe reports that more pages exist;
/// retrieval always stops once the last page has been read.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
810
811/// Polls the Stripe events API periodically to reconcile the records in our
812/// database with the data in Stripe.
813pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
814 let Some(stripe_client) = app.stripe_client.clone() else {
815 log::warn!("failed to retrieve Stripe client");
816 return;
817 };
818
819 let executor = app.executor.clone();
820 executor.spawn_detached({
821 let executor = executor.clone();
822 async move {
823 loop {
824 poll_stripe_events(&app, &rpc_server, &stripe_client)
825 .await
826 .log_err();
827
828 executor.sleep(POLL_EVENTS_INTERVAL).await;
829 }
830 }
831 });
832}
833
834async fn poll_stripe_events(
835 app: &Arc<AppState>,
836 rpc_server: &Arc<Server>,
837 stripe_client: &stripe::Client,
838) -> anyhow::Result<()> {
839 fn event_type_to_string(event_type: EventType) -> String {
840 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
841 // so we need to unquote it.
842 event_type.to_string().trim_matches('"').to_string()
843 }
844
845 let event_types = [
846 EventType::CustomerCreated,
847 EventType::CustomerUpdated,
848 EventType::CustomerSubscriptionCreated,
849 EventType::CustomerSubscriptionUpdated,
850 EventType::CustomerSubscriptionPaused,
851 EventType::CustomerSubscriptionResumed,
852 EventType::CustomerSubscriptionDeleted,
853 ]
854 .into_iter()
855 .map(event_type_to_string)
856 .collect::<Vec<_>>();
857
858 let mut pages_of_already_processed_events = 0;
859 let mut unprocessed_events = Vec::new();
860
861 log::info!(
862 "Stripe events: starting retrieval for {}",
863 event_types.join(", ")
864 );
865 let mut params = ListEvents::new();
866 params.types = Some(event_types.clone());
867 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
868
869 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
870 .await?
871 .paginate(params);
872
873 loop {
874 let processed_event_ids = {
875 let event_ids = event_pages
876 .page
877 .data
878 .iter()
879 .map(|event| event.id.as_str())
880 .collect::<Vec<_>>();
881 app.db
882 .get_processed_stripe_events_by_event_ids(&event_ids)
883 .await?
884 .into_iter()
885 .map(|event| event.stripe_event_id)
886 .collect::<Vec<_>>()
887 };
888
889 let mut processed_events_in_page = 0;
890 let events_in_page = event_pages.page.data.len();
891 for event in &event_pages.page.data {
892 if processed_event_ids.contains(&event.id.to_string()) {
893 processed_events_in_page += 1;
894 log::debug!("Stripe events: already processed '{}', skipping", event.id);
895 } else {
896 unprocessed_events.push(event.clone());
897 }
898 }
899
900 if processed_events_in_page == events_in_page {
901 pages_of_already_processed_events += 1;
902 }
903
904 if event_pages.page.has_more {
905 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
906 {
907 log::info!(
908 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
909 );
910 break;
911 } else {
912 log::info!("Stripe events: retrieving next page");
913 event_pages = event_pages.next(&stripe_client).await?;
914 }
915 } else {
916 break;
917 }
918 }
919
920 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
921
922 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
923 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
924
925 for event in unprocessed_events {
926 let event_id = event.id.clone();
927 let processed_event_params = CreateProcessedStripeEventParams {
928 stripe_event_id: event.id.to_string(),
929 stripe_event_type: event_type_to_string(event.type_),
930 stripe_event_created_timestamp: event.created,
931 };
932
933 // If the event has happened too far in the past, we don't want to
934 // process it and risk overwriting other more-recent updates.
935 //
936 // 1 day was chosen arbitrarily. This could be made longer or shorter.
937 let one_day = Duration::from_secs(24 * 60 * 60);
938 let a_day_ago = Utc::now() - one_day;
939 if a_day_ago.timestamp() > event.created {
940 log::info!(
941 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
942 event_id
943 );
944 app.db
945 .create_processed_stripe_event(&processed_event_params)
946 .await?;
947
948 continue;
949 }
950
951 let process_result = match event.type_ {
952 EventType::CustomerCreated | EventType::CustomerUpdated => {
953 handle_customer_event(app, stripe_client, event).await
954 }
955 EventType::CustomerSubscriptionCreated
956 | EventType::CustomerSubscriptionUpdated
957 | EventType::CustomerSubscriptionPaused
958 | EventType::CustomerSubscriptionResumed
959 | EventType::CustomerSubscriptionDeleted => {
960 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
961 }
962 _ => Ok(()),
963 };
964
965 if let Some(()) = process_result
966 .with_context(|| format!("failed to process event {event_id} successfully"))
967 .log_err()
968 {
969 app.db
970 .create_processed_stripe_event(&processed_event_params)
971 .await?;
972 }
973 }
974
975 Ok(())
976}
977
978async fn handle_customer_event(
979 app: &Arc<AppState>,
980 _stripe_client: &stripe::Client,
981 event: stripe::Event,
982) -> anyhow::Result<()> {
983 let EventObject::Customer(customer) = event.data.object else {
984 bail!("unexpected event payload for {}", event.id);
985 };
986
987 log::info!("handling Stripe {} event: {}", event.type_, event.id);
988
989 let Some(email) = customer.email else {
990 log::info!("Stripe customer has no email: skipping");
991 return Ok(());
992 };
993
994 let Some(user) = app.db.get_user_by_email(&email).await? else {
995 log::info!("no user found for email: skipping");
996 return Ok(());
997 };
998
999 if let Some(existing_customer) = app
1000 .db
1001 .get_billing_customer_by_stripe_customer_id(&customer.id)
1002 .await?
1003 {
1004 app.db
1005 .update_billing_customer(
1006 existing_customer.id,
1007 &UpdateBillingCustomerParams {
1008 // For now we just leave the information as-is, as it is not
1009 // likely to change.
1010 ..Default::default()
1011 },
1012 )
1013 .await?;
1014 } else {
1015 app.db
1016 .create_billing_customer(&CreateBillingCustomerParams {
1017 user_id: user.id,
1018 stripe_customer_id: customer.id.to_string(),
1019 })
1020 .await?;
1021 }
1022
1023 Ok(())
1024}
1025
1026async fn sync_subscription(
1027 app: &Arc<AppState>,
1028 stripe_client: &stripe::Client,
1029 subscription: stripe::Subscription,
1030) -> anyhow::Result<billing_customer::Model> {
1031 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1032 stripe_billing
1033 .determine_subscription_kind(&subscription)
1034 .await
1035 } else {
1036 None
1037 };
1038
1039 let billing_customer =
1040 find_or_create_billing_customer(app, stripe_client, subscription.customer)
1041 .await?
1042 .context("billing customer not found")?;
1043
1044 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1045 if subscription.status == SubscriptionStatus::Trialing {
1046 let current_period_start =
1047 DateTime::from_timestamp(subscription.current_period_start, 0)
1048 .context("No trial subscription period start")?;
1049
1050 app.db
1051 .update_billing_customer(
1052 billing_customer.id,
1053 &UpdateBillingCustomerParams {
1054 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1055 ..Default::default()
1056 },
1057 )
1058 .await?;
1059 }
1060 }
1061
1062 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1063 && subscription
1064 .cancellation_details
1065 .as_ref()
1066 .and_then(|details| details.reason)
1067 .map_or(false, |reason| {
1068 reason == CancellationDetailsReason::PaymentFailed
1069 });
1070
1071 if was_canceled_due_to_payment_failure {
1072 app.db
1073 .update_billing_customer(
1074 billing_customer.id,
1075 &UpdateBillingCustomerParams {
1076 has_overdue_invoices: ActiveValue::set(true),
1077 ..Default::default()
1078 },
1079 )
1080 .await?;
1081 }
1082
1083 if let Some(existing_subscription) = app
1084 .db
1085 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1086 .await?
1087 {
1088 app.db
1089 .update_billing_subscription(
1090 existing_subscription.id,
1091 &UpdateBillingSubscriptionParams {
1092 billing_customer_id: ActiveValue::set(billing_customer.id),
1093 kind: ActiveValue::set(subscription_kind),
1094 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1095 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1096 stripe_cancel_at: ActiveValue::set(
1097 subscription
1098 .cancel_at
1099 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1100 .map(|time| time.naive_utc()),
1101 ),
1102 stripe_cancellation_reason: ActiveValue::set(
1103 subscription
1104 .cancellation_details
1105 .and_then(|details| details.reason)
1106 .map(|reason| reason.into()),
1107 ),
1108 stripe_current_period_start: ActiveValue::set(Some(
1109 subscription.current_period_start,
1110 )),
1111 stripe_current_period_end: ActiveValue::set(Some(
1112 subscription.current_period_end,
1113 )),
1114 },
1115 )
1116 .await?;
1117 } else {
1118 if let Some(existing_subscription) = app
1119 .db
1120 .get_active_billing_subscription(billing_customer.user_id)
1121 .await?
1122 {
1123 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1124 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1125 {
1126 let stripe_subscription_id = existing_subscription
1127 .stripe_subscription_id
1128 .parse::<stripe::SubscriptionId>()
1129 .context("failed to parse Stripe subscription ID from database")?;
1130
1131 Subscription::cancel(
1132 &stripe_client,
1133 &stripe_subscription_id,
1134 stripe::CancelSubscription {
1135 invoice_now: None,
1136 ..Default::default()
1137 },
1138 )
1139 .await?;
1140 } else {
1141 // If the user already has an active billing subscription, ignore the
1142 // event and return an `Ok` to signal that it was processed
1143 // successfully.
1144 //
1145 // There is the possibility that this could cause us to not create a
1146 // subscription in the following scenario:
1147 //
1148 // 1. User has an active subscription A
1149 // 2. User cancels subscription A
1150 // 3. User creates a new subscription B
1151 // 4. We process the new subscription B before the cancellation of subscription A
1152 // 5. User ends up with no subscriptions
1153 //
1154 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1155
1156 log::info!(
1157 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1158 user_id = billing_customer.user_id,
1159 subscription_id = subscription.id
1160 );
1161 return Ok(billing_customer);
1162 }
1163 }
1164
1165 app.db
1166 .create_billing_subscription(&CreateBillingSubscriptionParams {
1167 billing_customer_id: billing_customer.id,
1168 kind: subscription_kind,
1169 stripe_subscription_id: subscription.id.to_string(),
1170 stripe_subscription_status: subscription.status.into(),
1171 stripe_cancellation_reason: subscription
1172 .cancellation_details
1173 .and_then(|details| details.reason)
1174 .map(|reason| reason.into()),
1175 stripe_current_period_start: Some(subscription.current_period_start),
1176 stripe_current_period_end: Some(subscription.current_period_end),
1177 })
1178 .await?;
1179 }
1180
1181 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1182 if subscription.status == SubscriptionStatus::Canceled
1183 || subscription.status == SubscriptionStatus::Paused
1184 {
1185 let already_has_active_billing_subscription = app
1186 .db
1187 .has_active_billing_subscription(billing_customer.user_id)
1188 .await?;
1189 if !already_has_active_billing_subscription {
1190 let stripe_customer_id = billing_customer
1191 .stripe_customer_id
1192 .parse::<stripe::CustomerId>()
1193 .context("failed to parse Stripe customer ID from database")?;
1194
1195 stripe_billing
1196 .subscribe_to_zed_free(stripe_customer_id)
1197 .await?;
1198 }
1199 }
1200 }
1201
1202 Ok(billing_customer)
1203}
1204
1205async fn handle_customer_subscription_event(
1206 app: &Arc<AppState>,
1207 rpc_server: &Arc<Server>,
1208 stripe_client: &stripe::Client,
1209 event: stripe::Event,
1210) -> anyhow::Result<()> {
1211 let EventObject::Subscription(subscription) = event.data.object else {
1212 bail!("unexpected event payload for {}", event.id);
1213 };
1214
1215 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1216
1217 let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1218
1219 // When the user's subscription changes, push down any changes to their plan.
1220 rpc_server
1221 .update_plan_for_user(billing_customer.user_id)
1222 .await
1223 .trace_err();
1224
1225 // When the user's subscription changes, we want to refresh their LLM tokens
1226 // to either grant/revoke access.
1227 rpc_server
1228 .refresh_llm_tokens_for_user(billing_customer.user_id)
1229 .await;
1230
1231 Ok(())
1232}
1233
1234#[derive(Debug, Deserialize)]
1235struct GetCurrentUsageParams {
1236 github_user_id: i32,
1237}
1238
1239#[derive(Debug, Serialize)]
1240struct UsageCounts {
1241 pub used: i32,
1242 pub limit: Option<i32>,
1243 pub remaining: Option<i32>,
1244}
1245
1246#[derive(Debug, Serialize)]
1247struct ModelRequestUsage {
1248 pub model: String,
1249 pub mode: CompletionMode,
1250 pub requests: i32,
1251}
1252
1253#[derive(Debug, Serialize)]
1254struct CurrentUsage {
1255 pub model_requests: UsageCounts,
1256 pub model_request_usage: Vec<ModelRequestUsage>,
1257 pub edit_predictions: UsageCounts,
1258}
1259
1260#[derive(Debug, Default, Serialize)]
1261struct GetCurrentUsageResponse {
1262 pub plan: String,
1263 pub current_usage: Option<CurrentUsage>,
1264}
1265
1266async fn get_current_usage(
1267 Extension(app): Extension<Arc<AppState>>,
1268 Query(params): Query<GetCurrentUsageParams>,
1269) -> Result<Json<GetCurrentUsageResponse>> {
1270 let user = app
1271 .db
1272 .get_user_by_github_user_id(params.github_user_id)
1273 .await?
1274 .context("user not found")?;
1275
1276 let feature_flags = app.db.get_user_flags(user.id).await?;
1277 let has_extended_trial = feature_flags
1278 .iter()
1279 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1280
1281 let Some(llm_db) = app.llm_db.clone() else {
1282 return Err(Error::http(
1283 StatusCode::NOT_IMPLEMENTED,
1284 "LLM database not available".into(),
1285 ));
1286 };
1287
1288 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1289 return Ok(Json(GetCurrentUsageResponse::default()));
1290 };
1291
1292 let subscription_period = maybe!({
1293 let period_start_at = subscription.current_period_start_at()?;
1294 let period_end_at = subscription.current_period_end_at()?;
1295
1296 Some((period_start_at, period_end_at))
1297 });
1298
1299 let Some((period_start_at, period_end_at)) = subscription_period else {
1300 return Ok(Json(GetCurrentUsageResponse::default()));
1301 };
1302
1303 let usage = llm_db
1304 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1305 .await?;
1306
1307 let plan = subscription
1308 .kind
1309 .map(Into::into)
1310 .unwrap_or(zed_llm_client::Plan::ZedFree);
1311
1312 let model_requests_limit = match plan.model_requests_limit() {
1313 zed_llm_client::UsageLimit::Limited(limit) => {
1314 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1315 1_000
1316 } else {
1317 limit
1318 };
1319
1320 Some(limit)
1321 }
1322 zed_llm_client::UsageLimit::Unlimited => None,
1323 };
1324
1325 let edit_predictions_limit = match plan.edit_predictions_limit() {
1326 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1327 zed_llm_client::UsageLimit::Unlimited => None,
1328 };
1329
1330 let Some(usage) = usage else {
1331 return Ok(Json(GetCurrentUsageResponse {
1332 plan: plan.as_str().to_string(),
1333 current_usage: Some(CurrentUsage {
1334 model_requests: UsageCounts {
1335 used: 0,
1336 limit: model_requests_limit,
1337 remaining: model_requests_limit,
1338 },
1339 model_request_usage: Vec::new(),
1340 edit_predictions: UsageCounts {
1341 used: 0,
1342 limit: edit_predictions_limit,
1343 remaining: edit_predictions_limit,
1344 },
1345 }),
1346 }));
1347 };
1348
1349 let subscription_usage_meters = llm_db
1350 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1351 .await?;
1352
1353 let model_request_usage = subscription_usage_meters
1354 .into_iter()
1355 .filter_map(|(usage_meter, _usage)| {
1356 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1357
1358 Some(ModelRequestUsage {
1359 model: model.name.clone(),
1360 mode: usage_meter.mode,
1361 requests: usage_meter.requests,
1362 })
1363 })
1364 .collect::<Vec<_>>();
1365
1366 Ok(Json(GetCurrentUsageResponse {
1367 plan: plan.as_str().to_string(),
1368 current_usage: Some(CurrentUsage {
1369 model_requests: UsageCounts {
1370 used: usage.model_requests,
1371 limit: model_requests_limit,
1372 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1373 },
1374 model_request_usage,
1375 edit_predictions: UsageCounts {
1376 used: usage.edit_predictions,
1377 limit: edit_predictions_limit,
1378 remaining: edit_predictions_limit
1379 .map(|limit| (limit - usage.edit_predictions).max(0)),
1380 },
1381 }),
1382 }))
1383}
1384
1385impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1386 fn from(value: SubscriptionStatus) -> Self {
1387 match value {
1388 SubscriptionStatus::Incomplete => Self::Incomplete,
1389 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1390 SubscriptionStatus::Trialing => Self::Trialing,
1391 SubscriptionStatus::Active => Self::Active,
1392 SubscriptionStatus::PastDue => Self::PastDue,
1393 SubscriptionStatus::Canceled => Self::Canceled,
1394 SubscriptionStatus::Unpaid => Self::Unpaid,
1395 SubscriptionStatus::Paused => Self::Paused,
1396 }
1397 }
1398}
1399
1400impl From<CancellationDetailsReason> for StripeCancellationReason {
1401 fn from(value: CancellationDetailsReason) -> Self {
1402 match value {
1403 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1404 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1405 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1406 }
1407 }
1408}
1409
1410/// Finds or creates a billing customer using the provided customer.
1411pub async fn find_or_create_billing_customer(
1412 app: &Arc<AppState>,
1413 stripe_client: &stripe::Client,
1414 customer_or_id: Expandable<Customer>,
1415) -> anyhow::Result<Option<billing_customer::Model>> {
1416 let customer_id = match &customer_or_id {
1417 Expandable::Id(id) => id,
1418 Expandable::Object(customer) => customer.id.as_ref(),
1419 };
1420
1421 // If we already have a billing customer record associated with the Stripe customer,
1422 // there's nothing more we need to do.
1423 if let Some(billing_customer) = app
1424 .db
1425 .get_billing_customer_by_stripe_customer_id(customer_id)
1426 .await?
1427 {
1428 return Ok(Some(billing_customer));
1429 }
1430
1431 // If all we have is a customer ID, resolve it to a full customer record by
1432 // hitting the Stripe API.
1433 let customer = match customer_or_id {
1434 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1435 Expandable::Object(customer) => *customer,
1436 };
1437
1438 let Some(email) = customer.email else {
1439 return Ok(None);
1440 };
1441
1442 let Some(user) = app.db.get_user_by_email(&email).await? else {
1443 return Ok(None);
1444 };
1445
1446 let billing_customer = app
1447 .db
1448 .create_billing_customer(&CreateBillingCustomerParams {
1449 user_id: user.id,
1450 stripe_customer_id: customer.id.to_string(),
1451 })
1452 .await?;
1453
1454 Ok(Some(billing_customer))
1455}
1456
/// How long the periodic usage sync sleeps between two sync attempts.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1458
1459pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1460 let Some(stripe_billing) = app.stripe_billing.clone() else {
1461 log::warn!("failed to retrieve Stripe billing object");
1462 return;
1463 };
1464 let Some(llm_db) = app.llm_db.clone() else {
1465 log::warn!("failed to retrieve LLM database");
1466 return;
1467 };
1468
1469 let executor = app.executor.clone();
1470 executor.spawn_detached({
1471 let executor = executor.clone();
1472 async move {
1473 loop {
1474 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1475 .await
1476 .context("failed to sync LLM request usage to Stripe")
1477 .trace_err();
1478 executor
1479 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1480 .await;
1481 }
1482 }
1483 });
1484}
1485
1486async fn sync_model_request_usage_with_stripe(
1487 app: &Arc<AppState>,
1488 llm_db: &Arc<LlmDatabase>,
1489 stripe_billing: &Arc<StripeBilling>,
1490) -> anyhow::Result<()> {
1491 let staff_users = app.db.get_staff_users().await?;
1492 let staff_user_ids = staff_users
1493 .iter()
1494 .map(|user| user.id)
1495 .collect::<HashSet<UserId>>();
1496
1497 let usage_meters = llm_db
1498 .get_current_subscription_usage_meters(Utc::now())
1499 .await?;
1500 let usage_meters = usage_meters
1501 .into_iter()
1502 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1503 .collect::<Vec<_>>();
1504 let user_ids = usage_meters
1505 .iter()
1506 .map(|(_, usage)| usage.user_id)
1507 .collect::<HashSet<UserId>>();
1508 let billing_subscriptions = app
1509 .db
1510 .get_active_zed_pro_billing_subscriptions(user_ids)
1511 .await?;
1512
1513 let claude_sonnet_4 = stripe_billing
1514 .find_price_by_lookup_key("claude-sonnet-4-requests")
1515 .await?;
1516 let claude_sonnet_4_max = stripe_billing
1517 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1518 .await?;
1519 let claude_opus_4 = stripe_billing
1520 .find_price_by_lookup_key("claude-opus-4-requests")
1521 .await?;
1522 let claude_opus_4_max = stripe_billing
1523 .find_price_by_lookup_key("claude-opus-4-requests-max")
1524 .await?;
1525 let claude_3_5_sonnet = stripe_billing
1526 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1527 .await?;
1528 let claude_3_7_sonnet = stripe_billing
1529 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1530 .await?;
1531 let claude_3_7_sonnet_max = stripe_billing
1532 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1533 .await?;
1534
1535 for (usage_meter, usage) in usage_meters {
1536 maybe!(async {
1537 let Some((billing_customer, billing_subscription)) =
1538 billing_subscriptions.get(&usage.user_id)
1539 else {
1540 bail!(
1541 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1542 usage.user_id
1543 );
1544 };
1545
1546 let stripe_customer_id = billing_customer
1547 .stripe_customer_id
1548 .parse::<stripe::CustomerId>()
1549 .context("failed to parse Stripe customer ID from database")?;
1550 let stripe_subscription_id = billing_subscription
1551 .stripe_subscription_id
1552 .parse::<stripe::SubscriptionId>()
1553 .context("failed to parse Stripe subscription ID from database")?;
1554
1555 let model = llm_db.model_by_id(usage_meter.model_id)?;
1556
1557 let (price, meter_event_name) = match model.name.as_str() {
1558 "claude-opus-4" => match usage_meter.mode {
1559 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1560 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1561 },
1562 "claude-sonnet-4" => match usage_meter.mode {
1563 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1564 CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
1565 },
1566 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1567 "claude-3-7-sonnet" => match usage_meter.mode {
1568 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1569 CompletionMode::Max => {
1570 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1571 }
1572 },
1573 model_name => {
1574 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1575 }
1576 };
1577
1578 stripe_billing
1579 .subscribe_to_price(&stripe_subscription_id, price)
1580 .await?;
1581 stripe_billing
1582 .bill_model_request_usage(
1583 &stripe_customer_id,
1584 meter_event_name,
1585 usage_meter.requests,
1586 )
1587 .await?;
1588
1589 Ok(())
1590 })
1591 .await
1592 .log_err();
1593 }
1594
1595 Ok(())
1596}