1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
21 Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
31use crate::rpc::{ResultExt as _, Server};
32use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
33use crate::{AppState, Error, Result};
34use crate::{db::UserId, llm::db::LlmDatabase};
35use crate::{
36 db::{
37 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
38 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
39 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
40 },
41 stripe_billing::StripeBilling,
42};
43
44pub fn router() -> Router {
45 Router::new()
46 .route(
47 "/billing/preferences",
48 get(get_billing_preferences).put(update_billing_preferences),
49 )
50 .route(
51 "/billing/subscriptions",
52 get(list_billing_subscriptions).post(create_billing_subscription),
53 )
54 .route(
55 "/billing/subscriptions/manage",
56 post(manage_billing_subscription),
57 )
58 .route(
59 "/billing/subscriptions/migrate",
60 post(migrate_to_new_billing),
61 )
62 .route(
63 "/billing/subscriptions/sync",
64 post(sync_billing_subscription),
65 )
66 .route("/billing/usage", get(get_current_usage))
67}
68
/// Query-string parameters accepted by `GET /billing/preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID used to look up the user whose preferences are requested.
    github_user_id: i32,
}
73
/// JSON payload returned by both the "get" and "update" billing-preferences handlers.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, formatted as an RFC 3339 UTC timestamp with
    /// millisecond precision; `None` when no trial start is recorded.
    trial_started_at: Option<String>,
    /// Maximum monthly LLM usage spend, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled.
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
81
82async fn get_billing_preferences(
83 Extension(app): Extension<Arc<AppState>>,
84 Query(params): Query<GetBillingPreferencesParams>,
85) -> Result<Json<BillingPreferencesResponse>> {
86 let user = app
87 .db
88 .get_user_by_github_user_id(params.github_user_id)
89 .await?
90 .context("user not found")?;
91
92 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
93 let preferences = app.db.get_billing_preferences(user.id).await?;
94
95 Ok(Json(BillingPreferencesResponse {
96 trial_started_at: billing_customer
97 .and_then(|billing_customer| billing_customer.trial_started_at)
98 .map(|trial_started_at| {
99 trial_started_at
100 .and_utc()
101 .to_rfc3339_opts(SecondsFormat::Millis, true)
102 }),
103 max_monthly_llm_usage_spending_in_cents: preferences
104 .as_ref()
105 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
106 preferences.max_monthly_llm_usage_spending_in_cents
107 }),
108 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
109 preferences.model_request_overages_enabled
110 }),
111 model_request_overages_spend_limit_in_cents: preferences
112 .as_ref()
113 .map_or(0, |preferences| {
114 preferences.model_request_overages_spend_limit_in_cents
115 }),
116 }))
117}
118
/// JSON body accepted by `PUT /billing/preferences`.
///
/// Every preference field is `#[serde(default)]`, so a field omitted from the
/// request deserializes to `0` / `false` rather than being rejected.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID used to look up the user whose preferences are updated.
    github_user_id: i32,
    /// Requested maximum monthly LLM usage spend, in cents.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Requested spend limit for model request overages, in cents.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
129
130async fn update_billing_preferences(
131 Extension(app): Extension<Arc<AppState>>,
132 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
133 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
134) -> Result<Json<BillingPreferencesResponse>> {
135 let user = app
136 .db
137 .get_user_by_github_user_id(body.github_user_id)
138 .await?
139 .context("user not found")?;
140
141 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
142
143 let max_monthly_llm_usage_spending_in_cents =
144 body.max_monthly_llm_usage_spending_in_cents.max(0);
145 let model_request_overages_spend_limit_in_cents =
146 body.model_request_overages_spend_limit_in_cents.max(0);
147
148 let billing_preferences =
149 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
150 app.db
151 .update_billing_preferences(
152 user.id,
153 &UpdateBillingPreferencesParams {
154 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
155 max_monthly_llm_usage_spending_in_cents,
156 ),
157 model_request_overages_enabled: ActiveValue::set(
158 body.model_request_overages_enabled,
159 ),
160 model_request_overages_spend_limit_in_cents: ActiveValue::set(
161 model_request_overages_spend_limit_in_cents,
162 ),
163 },
164 )
165 .await?
166 } else {
167 app.db
168 .create_billing_preferences(
169 user.id,
170 &crate::db::CreateBillingPreferencesParams {
171 max_monthly_llm_usage_spending_in_cents,
172 model_request_overages_enabled: body.model_request_overages_enabled,
173 model_request_overages_spend_limit_in_cents,
174 },
175 )
176 .await?
177 };
178
179 SnowflakeRow::new(
180 "Billing Preferences Updated",
181 Some(user.metrics_id),
182 user.admin,
183 None,
184 json!({
185 "user_id": user.id,
186 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
187 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
188 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
189 }),
190 )
191 .write(&app.kinesis_client, &app.config.kinesis_stream)
192 .await
193 .log_err();
194
195 rpc_server.refresh_llm_tokens_for_user(user.id).await;
196
197 Ok(Json(BillingPreferencesResponse {
198 trial_started_at: billing_customer
199 .and_then(|billing_customer| billing_customer.trial_started_at)
200 .map(|trial_started_at| {
201 trial_started_at
202 .and_utc()
203 .to_rfc3339_opts(SecondsFormat::Millis, true)
204 }),
205 max_monthly_llm_usage_spending_in_cents: billing_preferences
206 .max_monthly_llm_usage_spending_in_cents,
207 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
208 model_request_overages_spend_limit_in_cents: billing_preferences
209 .model_request_overages_spend_limit_in_cents,
210 }))
211}
212
/// Query-string parameters accepted by `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}
217
/// JSON representation of a single billing subscription in the list response.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Database ID of the billing subscription.
    id: BillingSubscriptionId,
    /// Human-readable plan name derived from the subscription kind.
    name: String,
    /// Status of the subscription as recorded from Stripe.
    status: StripeSubscriptionStatus,
    /// End of the trial as an RFC 3339 timestamp; only populated for trial subscriptions.
    trial_end_at: Option<String>,
    /// Scheduled cancellation time as an RFC 3339 timestamp, if any.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}
228
/// JSON payload returned by `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The user's billing subscriptions.
    subscriptions: Vec<BillingSubscriptionJson>,
}
233
234async fn list_billing_subscriptions(
235 Extension(app): Extension<Arc<AppState>>,
236 Query(params): Query<ListBillingSubscriptionsParams>,
237) -> Result<Json<ListBillingSubscriptionsResponse>> {
238 let user = app
239 .db
240 .get_user_by_github_user_id(params.github_user_id)
241 .await?
242 .context("user not found")?;
243
244 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
245
246 Ok(Json(ListBillingSubscriptionsResponse {
247 subscriptions: subscriptions
248 .into_iter()
249 .map(|subscription| BillingSubscriptionJson {
250 id: subscription.id,
251 name: match subscription.kind {
252 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
253 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
254 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
255 None => "Zed LLM Usage".to_string(),
256 },
257 status: subscription.stripe_subscription_status,
258 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
259 maybe!({
260 let end_at = subscription.stripe_current_period_end?;
261 let end_at = DateTime::from_timestamp(end_at, 0)?;
262
263 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
264 })
265 } else {
266 None
267 },
268 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
269 cancel_at
270 .and_utc()
271 .to_rfc3339_opts(SecondsFormat::Millis, true)
272 }),
273 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
274 && subscription.stripe_subscription_status.is_cancelable()
275 && subscription.stripe_cancel_at.is_none(),
276 })
277 .collect(),
278 }))
279}
280
/// Product a checkout session can be created for.
///
/// Deserialized from snake_case strings (`"zed_pro"`, `"zed_pro_trial"`).
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// The Zed Pro plan.
    ZedPro,
    /// The trial of the Zed Pro plan.
    ZedProTrial,
}
287
/// JSON body accepted by `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the subscribing user.
    github_user_id: i32,
    /// The product to check out.
    product: ProductCode,
}
293
/// JSON payload returned by `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the checkout session created for the user.
    checkout_session_url: String,
}
298
299/// Initiates a Stripe Checkout session for creating a billing subscription.
300async fn create_billing_subscription(
301 Extension(app): Extension<Arc<AppState>>,
302 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
303) -> Result<Json<CreateBillingSubscriptionResponse>> {
304 let user = app
305 .db
306 .get_user_by_github_user_id(body.github_user_id)
307 .await?
308 .context("user not found")?;
309
310 let Some(stripe_billing) = app.stripe_billing.clone() else {
311 log::error!("failed to retrieve Stripe billing object");
312 Err(Error::http(
313 StatusCode::NOT_IMPLEMENTED,
314 "not supported".into(),
315 ))?
316 };
317
318 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
319 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
320 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
321
322 if !is_checkout_allowed {
323 return Err(Error::http(
324 StatusCode::CONFLICT,
325 "user already has an active subscription".into(),
326 ));
327 }
328 }
329
330 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
331 if let Some(existing_billing_customer) = &existing_billing_customer {
332 if existing_billing_customer.has_overdue_invoices {
333 return Err(Error::http(
334 StatusCode::PAYMENT_REQUIRED,
335 "user has overdue invoices".into(),
336 ));
337 }
338 }
339
340 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
341 StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
342 } else {
343 stripe_billing
344 .find_or_create_customer_by_email(user.email_address.as_deref())
345 .await?
346 };
347
348 let success_url = format!(
349 "{}/account?checkout_complete=1",
350 app.config.zed_dot_dev_url()
351 );
352
353 let checkout_session_url = match body.product {
354 ProductCode::ZedPro => {
355 stripe_billing
356 .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
357 .await?
358 }
359 ProductCode::ZedProTrial => {
360 if let Some(existing_billing_customer) = &existing_billing_customer {
361 if existing_billing_customer.trial_started_at.is_some() {
362 return Err(Error::http(
363 StatusCode::FORBIDDEN,
364 "user already used free trial".into(),
365 ));
366 }
367 }
368
369 let feature_flags = app.db.get_user_flags(user.id).await?;
370
371 stripe_billing
372 .checkout_with_zed_pro_trial(
373 &customer_id,
374 &user.github_login,
375 feature_flags,
376 &success_url,
377 )
378 .await?
379 }
380 };
381
382 Ok(Json(CreateBillingSubscriptionResponse {
383 checkout_session_url,
384 }))
385}
386
/// What the user wants to do when managing a billing subscription.
///
/// Deserialized from snake_case strings (e.g. `"manage_subscription"`,
/// `"stop_cancellation"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}
403
/// JSON body accepted by `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user managing the subscription.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Optional path appended to the site base URL to build the return URL
    /// after a payment method update.
    redirect_to: Option<String>,
}
412
/// JSON payload returned by `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the billing portal session, or `None` when the request was
    /// fulfilled directly without opening a portal session.
    billing_portal_session_url: Option<String>,
}
417
418/// Initiates a Stripe customer portal session for managing a billing subscription.
419async fn manage_billing_subscription(
420 Extension(app): Extension<Arc<AppState>>,
421 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
422) -> Result<Json<ManageBillingSubscriptionResponse>> {
423 let user = app
424 .db
425 .get_user_by_github_user_id(body.github_user_id)
426 .await?
427 .context("user not found")?;
428
429 let Some(stripe_client) = app.stripe_client.clone() else {
430 log::error!("failed to retrieve Stripe client");
431 Err(Error::http(
432 StatusCode::NOT_IMPLEMENTED,
433 "not supported".into(),
434 ))?
435 };
436
437 let Some(stripe_billing) = app.stripe_billing.clone() else {
438 log::error!("failed to retrieve Stripe billing object");
439 Err(Error::http(
440 StatusCode::NOT_IMPLEMENTED,
441 "not supported".into(),
442 ))?
443 };
444
445 let customer = app
446 .db
447 .get_billing_customer_by_user_id(user.id)
448 .await?
449 .context("billing customer not found")?;
450 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
451 .context("failed to parse customer ID")?;
452
453 let subscription = app
454 .db
455 .get_billing_subscription_by_id(body.subscription_id)
456 .await?
457 .context("subscription not found")?;
458 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
459 .context("failed to parse subscription ID")?;
460
461 if body.intent == ManageSubscriptionIntent::StopCancellation {
462 let updated_stripe_subscription = Subscription::update(
463 &stripe_client,
464 &subscription_id,
465 stripe::UpdateSubscription {
466 cancel_at_period_end: Some(false),
467 ..Default::default()
468 },
469 )
470 .await?;
471
472 app.db
473 .update_billing_subscription(
474 subscription.id,
475 &UpdateBillingSubscriptionParams {
476 stripe_cancel_at: ActiveValue::set(
477 updated_stripe_subscription
478 .cancel_at
479 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
480 .map(|time| time.naive_utc()),
481 ),
482 ..Default::default()
483 },
484 )
485 .await?;
486
487 return Ok(Json(ManageBillingSubscriptionResponse {
488 billing_portal_session_url: None,
489 }));
490 }
491
492 let flow = match body.intent {
493 ManageSubscriptionIntent::ManageSubscription => None,
494 ManageSubscriptionIntent::UpgradeToPro => {
495 let zed_pro_price_id: stripe::PriceId =
496 stripe_billing.zed_pro_price_id().await?.try_into()?;
497 let zed_free_price_id: stripe::PriceId =
498 stripe_billing.zed_free_price_id().await?.try_into()?;
499
500 let stripe_subscription =
501 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
502
503 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
504 && stripe_subscription.items.data.iter().any(|item| {
505 item.price
506 .as_ref()
507 .map_or(false, |price| price.id == zed_pro_price_id)
508 });
509 if is_on_zed_pro_trial {
510 let payment_methods = PaymentMethod::list(
511 &stripe_client,
512 &stripe::ListPaymentMethods {
513 customer: Some(stripe_subscription.customer.id()),
514 ..Default::default()
515 },
516 )
517 .await?;
518
519 let has_payment_method = !payment_methods.data.is_empty();
520 if !has_payment_method {
521 return Err(Error::http(
522 StatusCode::BAD_REQUEST,
523 "missing payment method".into(),
524 ));
525 }
526
527 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
528 Subscription::update(
529 &stripe_client,
530 &stripe_subscription.id,
531 stripe::UpdateSubscription {
532 trial_end: Some(stripe::Scheduled::now()),
533 ..Default::default()
534 },
535 )
536 .await?;
537
538 return Ok(Json(ManageBillingSubscriptionResponse {
539 billing_portal_session_url: None,
540 }));
541 }
542
543 let subscription_item_to_update = stripe_subscription
544 .items
545 .data
546 .iter()
547 .find_map(|item| {
548 let price = item.price.as_ref()?;
549
550 if price.id == zed_free_price_id {
551 Some(item.id.clone())
552 } else {
553 None
554 }
555 })
556 .context("No subscription item to update")?;
557
558 Some(CreateBillingPortalSessionFlowData {
559 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
560 subscription_update_confirm: Some(
561 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
562 subscription: subscription.stripe_subscription_id,
563 items: vec![
564 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
565 id: subscription_item_to_update.to_string(),
566 price: Some(zed_pro_price_id.to_string()),
567 quantity: Some(1),
568 },
569 ],
570 discounts: None,
571 },
572 ),
573 ..Default::default()
574 })
575 }
576 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
577 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
578 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
579 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
580 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
581 return_url: format!(
582 "{}{path}",
583 app.config.zed_dot_dev_url(),
584 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
585 ),
586 }),
587 ..Default::default()
588 }),
589 ..Default::default()
590 }),
591 ManageSubscriptionIntent::Cancel => {
592 if subscription.kind == Some(SubscriptionKind::ZedFree) {
593 return Err(Error::http(
594 StatusCode::BAD_REQUEST,
595 "free subscription cannot be canceled".into(),
596 ));
597 }
598
599 Some(CreateBillingPortalSessionFlowData {
600 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
601 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
602 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
603 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
604 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
605 }),
606 ..Default::default()
607 }),
608 subscription_cancel: Some(
609 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
610 subscription: subscription.stripe_subscription_id,
611 retention: None,
612 },
613 ),
614 ..Default::default()
615 })
616 }
617 ManageSubscriptionIntent::StopCancellation => unreachable!(),
618 };
619
620 let mut params = CreateBillingPortalSession::new(customer_id);
621 params.flow_data = flow;
622 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
623 params.return_url = Some(&return_url);
624
625 let session = BillingPortalSession::create(&stripe_client, params).await?;
626
627 Ok(Json(ManageBillingSubscriptionResponse {
628 billing_portal_session_url: Some(session.url),
629 }))
630}
631
/// JSON body accepted by `POST /billing/subscriptions/migrate`.
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
    /// GitHub user ID used to look up the user being migrated.
    github_user_id: i32,
}
636
/// JSON payload returned by `POST /billing/subscriptions/migrate`.
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
    /// The ID of the subscription that was canceled.
    ///
    /// `None` when the user had no active subscription to cancel.
    canceled_subscription_id: Option<String>,
}
642
643async fn migrate_to_new_billing(
644 Extension(app): Extension<Arc<AppState>>,
645 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
646) -> Result<Json<MigrateToNewBillingResponse>> {
647 let Some(stripe_client) = app.stripe_client.clone() else {
648 log::error!("failed to retrieve Stripe client");
649 Err(Error::http(
650 StatusCode::NOT_IMPLEMENTED,
651 "not supported".into(),
652 ))?
653 };
654
655 let user = app
656 .db
657 .get_user_by_github_user_id(body.github_user_id)
658 .await?
659 .context("user not found")?;
660
661 let old_billing_subscriptions_by_user = app
662 .db
663 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
664 .await?;
665
666 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
667 old_billing_subscriptions_by_user.get(&user.id)
668 {
669 let stripe_subscription_id = billing_subscription
670 .stripe_subscription_id
671 .parse::<stripe::SubscriptionId>()
672 .context("failed to parse Stripe subscription ID from database")?;
673
674 Subscription::cancel(
675 &stripe_client,
676 &stripe_subscription_id,
677 stripe::CancelSubscription {
678 invoice_now: Some(true),
679 ..Default::default()
680 },
681 )
682 .await?;
683
684 Some(stripe_subscription_id)
685 } else {
686 None
687 };
688
689 let all_feature_flags = app.db.list_feature_flags().await?;
690 let user_feature_flags = app.db.get_user_flags(user.id).await?;
691
692 for feature_flag in ["new-billing", "assistant2"] {
693 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
694 if already_in_feature_flag {
695 continue;
696 }
697
698 let feature_flag = all_feature_flags
699 .iter()
700 .find(|flag| flag.flag == feature_flag)
701 .context("failed to find feature flag: {feature_flag:?}")?;
702
703 app.db.add_user_flag(user.id, feature_flag.id).await?;
704 }
705
706 Ok(Json(MigrateToNewBillingResponse {
707 canceled_subscription_id: canceled_subscription_id
708 .map(|subscription_id| subscription_id.to_string()),
709 }))
710}
711
/// JSON body accepted by `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID used to look up the user whose subscriptions are synced.
    github_user_id: i32,
}
716
/// JSON payload returned by `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// Stripe customer ID stored for the user's billing customer.
    stripe_customer_id: String,
}
721
722async fn sync_billing_subscription(
723 Extension(app): Extension<Arc<AppState>>,
724 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
725) -> Result<Json<SyncBillingSubscriptionResponse>> {
726 let Some(stripe_client) = app.stripe_client.clone() else {
727 log::error!("failed to retrieve Stripe client");
728 Err(Error::http(
729 StatusCode::NOT_IMPLEMENTED,
730 "not supported".into(),
731 ))?
732 };
733
734 let user = app
735 .db
736 .get_user_by_github_user_id(body.github_user_id)
737 .await?
738 .context("user not found")?;
739
740 let billing_customer = app
741 .db
742 .get_billing_customer_by_user_id(user.id)
743 .await?
744 .context("billing customer not found")?;
745 let stripe_customer_id = billing_customer
746 .stripe_customer_id
747 .parse::<stripe::CustomerId>()
748 .context("failed to parse Stripe customer ID from database")?;
749
750 let subscriptions = Subscription::list(
751 &stripe_client,
752 &stripe::ListSubscriptions {
753 customer: Some(stripe_customer_id),
754 // Sync all non-canceled subscriptions.
755 status: None,
756 ..Default::default()
757 },
758 )
759 .await?;
760
761 for subscription in subscriptions.data {
762 let subscription_id = subscription.id.clone();
763
764 sync_subscription(&app, &stripe_client, subscription)
765 .await
766 .with_context(|| {
767 format!(
768 "failed to sync subscription {subscription_id} for user {}",
769 user.id,
770 )
771 })?;
772 }
773
774 Ok(Json(SyncBillingSubscriptionResponse {
775 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
776 }))
777}
778
/// The amount of time we wait in between each poll of Stripe events
/// (currently 5 seconds).
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes.
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
791
/// The maximum number of events to return per page when listing Stripe events.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
798
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed. The pages are counted across a whole poll; they
/// do not need to be consecutive.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
805
806/// Polls the Stripe events API periodically to reconcile the records in our
807/// database with the data in Stripe.
808pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
809 let Some(stripe_client) = app.stripe_client.clone() else {
810 log::warn!("failed to retrieve Stripe client");
811 return;
812 };
813
814 let executor = app.executor.clone();
815 executor.spawn_detached({
816 let executor = executor.clone();
817 async move {
818 loop {
819 poll_stripe_events(&app, &rpc_server, &stripe_client)
820 .await
821 .log_err();
822
823 executor.sleep(POLL_EVENTS_INTERVAL).await;
824 }
825 }
826 });
827}
828
829async fn poll_stripe_events(
830 app: &Arc<AppState>,
831 rpc_server: &Arc<Server>,
832 stripe_client: &stripe::Client,
833) -> anyhow::Result<()> {
834 fn event_type_to_string(event_type: EventType) -> String {
835 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
836 // so we need to unquote it.
837 event_type.to_string().trim_matches('"').to_string()
838 }
839
840 let event_types = [
841 EventType::CustomerCreated,
842 EventType::CustomerUpdated,
843 EventType::CustomerSubscriptionCreated,
844 EventType::CustomerSubscriptionUpdated,
845 EventType::CustomerSubscriptionPaused,
846 EventType::CustomerSubscriptionResumed,
847 EventType::CustomerSubscriptionDeleted,
848 ]
849 .into_iter()
850 .map(event_type_to_string)
851 .collect::<Vec<_>>();
852
853 let mut pages_of_already_processed_events = 0;
854 let mut unprocessed_events = Vec::new();
855
856 log::info!(
857 "Stripe events: starting retrieval for {}",
858 event_types.join(", ")
859 );
860 let mut params = ListEvents::new();
861 params.types = Some(event_types.clone());
862 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
863
864 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
865 .await?
866 .paginate(params);
867
868 loop {
869 let processed_event_ids = {
870 let event_ids = event_pages
871 .page
872 .data
873 .iter()
874 .map(|event| event.id.as_str())
875 .collect::<Vec<_>>();
876 app.db
877 .get_processed_stripe_events_by_event_ids(&event_ids)
878 .await?
879 .into_iter()
880 .map(|event| event.stripe_event_id)
881 .collect::<Vec<_>>()
882 };
883
884 let mut processed_events_in_page = 0;
885 let events_in_page = event_pages.page.data.len();
886 for event in &event_pages.page.data {
887 if processed_event_ids.contains(&event.id.to_string()) {
888 processed_events_in_page += 1;
889 log::debug!("Stripe events: already processed '{}', skipping", event.id);
890 } else {
891 unprocessed_events.push(event.clone());
892 }
893 }
894
895 if processed_events_in_page == events_in_page {
896 pages_of_already_processed_events += 1;
897 }
898
899 if event_pages.page.has_more {
900 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
901 {
902 log::info!(
903 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
904 );
905 break;
906 } else {
907 log::info!("Stripe events: retrieving next page");
908 event_pages = event_pages.next(&stripe_client).await?;
909 }
910 } else {
911 break;
912 }
913 }
914
915 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
916
917 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
918 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
919
920 for event in unprocessed_events {
921 let event_id = event.id.clone();
922 let processed_event_params = CreateProcessedStripeEventParams {
923 stripe_event_id: event.id.to_string(),
924 stripe_event_type: event_type_to_string(event.type_),
925 stripe_event_created_timestamp: event.created,
926 };
927
928 // If the event has happened too far in the past, we don't want to
929 // process it and risk overwriting other more-recent updates.
930 //
931 // 1 day was chosen arbitrarily. This could be made longer or shorter.
932 let one_day = Duration::from_secs(24 * 60 * 60);
933 let a_day_ago = Utc::now() - one_day;
934 if a_day_ago.timestamp() > event.created {
935 log::info!(
936 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
937 event_id
938 );
939 app.db
940 .create_processed_stripe_event(&processed_event_params)
941 .await?;
942
943 continue;
944 }
945
946 let process_result = match event.type_ {
947 EventType::CustomerCreated | EventType::CustomerUpdated => {
948 handle_customer_event(app, stripe_client, event).await
949 }
950 EventType::CustomerSubscriptionCreated
951 | EventType::CustomerSubscriptionUpdated
952 | EventType::CustomerSubscriptionPaused
953 | EventType::CustomerSubscriptionResumed
954 | EventType::CustomerSubscriptionDeleted => {
955 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
956 }
957 _ => Ok(()),
958 };
959
960 if let Some(()) = process_result
961 .with_context(|| format!("failed to process event {event_id} successfully"))
962 .log_err()
963 {
964 app.db
965 .create_processed_stripe_event(&processed_event_params)
966 .await?;
967 }
968 }
969
970 Ok(())
971}
972
973async fn handle_customer_event(
974 app: &Arc<AppState>,
975 _stripe_client: &stripe::Client,
976 event: stripe::Event,
977) -> anyhow::Result<()> {
978 let EventObject::Customer(customer) = event.data.object else {
979 bail!("unexpected event payload for {}", event.id);
980 };
981
982 log::info!("handling Stripe {} event: {}", event.type_, event.id);
983
984 let Some(email) = customer.email else {
985 log::info!("Stripe customer has no email: skipping");
986 return Ok(());
987 };
988
989 let Some(user) = app.db.get_user_by_email(&email).await? else {
990 log::info!("no user found for email: skipping");
991 return Ok(());
992 };
993
994 if let Some(existing_customer) = app
995 .db
996 .get_billing_customer_by_stripe_customer_id(&customer.id)
997 .await?
998 {
999 app.db
1000 .update_billing_customer(
1001 existing_customer.id,
1002 &UpdateBillingCustomerParams {
1003 // For now we just leave the information as-is, as it is not
1004 // likely to change.
1005 ..Default::default()
1006 },
1007 )
1008 .await?;
1009 } else {
1010 app.db
1011 .create_billing_customer(&CreateBillingCustomerParams {
1012 user_id: user.id,
1013 stripe_customer_id: customer.id.to_string(),
1014 })
1015 .await?;
1016 }
1017
1018 Ok(())
1019}
1020
1021async fn sync_subscription(
1022 app: &Arc<AppState>,
1023 stripe_client: &stripe::Client,
1024 subscription: stripe::Subscription,
1025) -> anyhow::Result<billing_customer::Model> {
1026 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1027 stripe_billing
1028 .determine_subscription_kind(&subscription)
1029 .await
1030 } else {
1031 None
1032 };
1033
1034 let billing_customer =
1035 find_or_create_billing_customer(app, stripe_client, subscription.customer)
1036 .await?
1037 .context("billing customer not found")?;
1038
1039 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1040 if subscription.status == SubscriptionStatus::Trialing {
1041 let current_period_start =
1042 DateTime::from_timestamp(subscription.current_period_start, 0)
1043 .context("No trial subscription period start")?;
1044
1045 app.db
1046 .update_billing_customer(
1047 billing_customer.id,
1048 &UpdateBillingCustomerParams {
1049 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1050 ..Default::default()
1051 },
1052 )
1053 .await?;
1054 }
1055 }
1056
1057 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1058 && subscription
1059 .cancellation_details
1060 .as_ref()
1061 .and_then(|details| details.reason)
1062 .map_or(false, |reason| {
1063 reason == CancellationDetailsReason::PaymentFailed
1064 });
1065
1066 if was_canceled_due_to_payment_failure {
1067 app.db
1068 .update_billing_customer(
1069 billing_customer.id,
1070 &UpdateBillingCustomerParams {
1071 has_overdue_invoices: ActiveValue::set(true),
1072 ..Default::default()
1073 },
1074 )
1075 .await?;
1076 }
1077
1078 if let Some(existing_subscription) = app
1079 .db
1080 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1081 .await?
1082 {
1083 app.db
1084 .update_billing_subscription(
1085 existing_subscription.id,
1086 &UpdateBillingSubscriptionParams {
1087 billing_customer_id: ActiveValue::set(billing_customer.id),
1088 kind: ActiveValue::set(subscription_kind),
1089 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1090 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1091 stripe_cancel_at: ActiveValue::set(
1092 subscription
1093 .cancel_at
1094 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1095 .map(|time| time.naive_utc()),
1096 ),
1097 stripe_cancellation_reason: ActiveValue::set(
1098 subscription
1099 .cancellation_details
1100 .and_then(|details| details.reason)
1101 .map(|reason| reason.into()),
1102 ),
1103 stripe_current_period_start: ActiveValue::set(Some(
1104 subscription.current_period_start,
1105 )),
1106 stripe_current_period_end: ActiveValue::set(Some(
1107 subscription.current_period_end,
1108 )),
1109 },
1110 )
1111 .await?;
1112 } else {
1113 if let Some(existing_subscription) = app
1114 .db
1115 .get_active_billing_subscription(billing_customer.user_id)
1116 .await?
1117 {
1118 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1119 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1120 {
1121 let stripe_subscription_id = existing_subscription
1122 .stripe_subscription_id
1123 .parse::<stripe::SubscriptionId>()
1124 .context("failed to parse Stripe subscription ID from database")?;
1125
1126 Subscription::cancel(
1127 &stripe_client,
1128 &stripe_subscription_id,
1129 stripe::CancelSubscription {
1130 invoice_now: None,
1131 ..Default::default()
1132 },
1133 )
1134 .await?;
1135 } else {
1136 // If the user already has an active billing subscription, ignore the
1137 // event and return an `Ok` to signal that it was processed
1138 // successfully.
1139 //
1140 // There is the possibility that this could cause us to not create a
1141 // subscription in the following scenario:
1142 //
1143 // 1. User has an active subscription A
1144 // 2. User cancels subscription A
1145 // 3. User creates a new subscription B
1146 // 4. We process the new subscription B before the cancellation of subscription A
1147 // 5. User ends up with no subscriptions
1148 //
1149 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1150
1151 log::info!(
1152 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1153 user_id = billing_customer.user_id,
1154 subscription_id = subscription.id
1155 );
1156 return Ok(billing_customer);
1157 }
1158 }
1159
1160 app.db
1161 .create_billing_subscription(&CreateBillingSubscriptionParams {
1162 billing_customer_id: billing_customer.id,
1163 kind: subscription_kind,
1164 stripe_subscription_id: subscription.id.to_string(),
1165 stripe_subscription_status: subscription.status.into(),
1166 stripe_cancellation_reason: subscription
1167 .cancellation_details
1168 .and_then(|details| details.reason)
1169 .map(|reason| reason.into()),
1170 stripe_current_period_start: Some(subscription.current_period_start),
1171 stripe_current_period_end: Some(subscription.current_period_end),
1172 })
1173 .await?;
1174 }
1175
1176 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1177 if subscription.status == SubscriptionStatus::Canceled
1178 || subscription.status == SubscriptionStatus::Paused
1179 {
1180 let already_has_active_billing_subscription = app
1181 .db
1182 .has_active_billing_subscription(billing_customer.user_id)
1183 .await?;
1184 if !already_has_active_billing_subscription {
1185 let stripe_customer_id =
1186 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1187
1188 stripe_billing
1189 .subscribe_to_zed_free(stripe_customer_id)
1190 .await?;
1191 }
1192 }
1193 }
1194
1195 Ok(billing_customer)
1196}
1197
1198async fn handle_customer_subscription_event(
1199 app: &Arc<AppState>,
1200 rpc_server: &Arc<Server>,
1201 stripe_client: &stripe::Client,
1202 event: stripe::Event,
1203) -> anyhow::Result<()> {
1204 let EventObject::Subscription(subscription) = event.data.object else {
1205 bail!("unexpected event payload for {}", event.id);
1206 };
1207
1208 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1209
1210 let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1211
1212 // When the user's subscription changes, push down any changes to their plan.
1213 rpc_server
1214 .update_plan_for_user(billing_customer.user_id)
1215 .await
1216 .trace_err();
1217
1218 // When the user's subscription changes, we want to refresh their LLM tokens
1219 // to either grant/revoke access.
1220 rpc_server
1221 .refresh_llm_tokens_for_user(billing_customer.user_id)
1222 .await;
1223
1224 Ok(())
1225}
1226
/// Query-string parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID used to look up the user whose usage is reported.
    github_user_id: i32,
}
1231
/// Usage of a single metered resource, as reported by `get_current_usage`.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Amount consumed so far.
    pub used: i32,
    /// Plan limit, or `None` when the plan's limit is unlimited.
    pub limit: Option<i32>,
    /// Amount left before the limit is reached (never negative), or `None`
    /// when the plan's limit is unlimited.
    pub remaining: Option<i32>,
}
1238
/// Request count for one model and completion mode, built from a subscription
/// usage meter in `get_current_usage`.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the requests were made against.
    pub model: String,
    /// Completion mode recorded on the usage meter.
    pub mode: CompletionMode,
    /// Number of requests recorded on the usage meter.
    pub requests: i32,
}
1245
/// Usage figures returned by `get_current_usage` for a user with an active
/// subscription period.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Overall model request usage against the plan's model request limit.
    pub model_requests: UsageCounts,
    /// Per-model, per-mode breakdown of model requests.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit prediction usage against the plan's edit prediction limit.
    pub edit_predictions: UsageCounts,
}
1252
/// Response body of `get_current_usage`.
///
/// The `Default` value (empty plan, no usage) is what `get_current_usage`
/// returns when the user has no active subscription or the subscription has
/// no current period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// Name of the user's plan, taken from the plan's `as_str()`.
    pub plan: String,
    /// Usage for the current subscription period, if one is available.
    pub current_usage: Option<CurrentUsage>,
}
1258
1259async fn get_current_usage(
1260 Extension(app): Extension<Arc<AppState>>,
1261 Query(params): Query<GetCurrentUsageParams>,
1262) -> Result<Json<GetCurrentUsageResponse>> {
1263 let user = app
1264 .db
1265 .get_user_by_github_user_id(params.github_user_id)
1266 .await?
1267 .context("user not found")?;
1268
1269 let feature_flags = app.db.get_user_flags(user.id).await?;
1270 let has_extended_trial = feature_flags
1271 .iter()
1272 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1273
1274 let Some(llm_db) = app.llm_db.clone() else {
1275 return Err(Error::http(
1276 StatusCode::NOT_IMPLEMENTED,
1277 "LLM database not available".into(),
1278 ));
1279 };
1280
1281 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1282 return Ok(Json(GetCurrentUsageResponse::default()));
1283 };
1284
1285 let subscription_period = maybe!({
1286 let period_start_at = subscription.current_period_start_at()?;
1287 let period_end_at = subscription.current_period_end_at()?;
1288
1289 Some((period_start_at, period_end_at))
1290 });
1291
1292 let Some((period_start_at, period_end_at)) = subscription_period else {
1293 return Ok(Json(GetCurrentUsageResponse::default()));
1294 };
1295
1296 let usage = llm_db
1297 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1298 .await?;
1299
1300 let plan = subscription
1301 .kind
1302 .map(Into::into)
1303 .unwrap_or(zed_llm_client::Plan::ZedFree);
1304
1305 let model_requests_limit = match plan.model_requests_limit() {
1306 zed_llm_client::UsageLimit::Limited(limit) => {
1307 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1308 1_000
1309 } else {
1310 limit
1311 };
1312
1313 Some(limit)
1314 }
1315 zed_llm_client::UsageLimit::Unlimited => None,
1316 };
1317
1318 let edit_predictions_limit = match plan.edit_predictions_limit() {
1319 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1320 zed_llm_client::UsageLimit::Unlimited => None,
1321 };
1322
1323 let Some(usage) = usage else {
1324 return Ok(Json(GetCurrentUsageResponse {
1325 plan: plan.as_str().to_string(),
1326 current_usage: Some(CurrentUsage {
1327 model_requests: UsageCounts {
1328 used: 0,
1329 limit: model_requests_limit,
1330 remaining: model_requests_limit,
1331 },
1332 model_request_usage: Vec::new(),
1333 edit_predictions: UsageCounts {
1334 used: 0,
1335 limit: edit_predictions_limit,
1336 remaining: edit_predictions_limit,
1337 },
1338 }),
1339 }));
1340 };
1341
1342 let subscription_usage_meters = llm_db
1343 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1344 .await?;
1345
1346 let model_request_usage = subscription_usage_meters
1347 .into_iter()
1348 .filter_map(|(usage_meter, _usage)| {
1349 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1350
1351 Some(ModelRequestUsage {
1352 model: model.name.clone(),
1353 mode: usage_meter.mode,
1354 requests: usage_meter.requests,
1355 })
1356 })
1357 .collect::<Vec<_>>();
1358
1359 Ok(Json(GetCurrentUsageResponse {
1360 plan: plan.as_str().to_string(),
1361 current_usage: Some(CurrentUsage {
1362 model_requests: UsageCounts {
1363 used: usage.model_requests,
1364 limit: model_requests_limit,
1365 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1366 },
1367 model_request_usage,
1368 edit_predictions: UsageCounts {
1369 used: usage.edit_predictions,
1370 limit: edit_predictions_limit,
1371 remaining: edit_predictions_limit
1372 .map(|limit| (limit - usage.edit_predictions).max(0)),
1373 },
1374 }),
1375 }))
1376}
1377
1378impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1379 fn from(value: SubscriptionStatus) -> Self {
1380 match value {
1381 SubscriptionStatus::Incomplete => Self::Incomplete,
1382 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1383 SubscriptionStatus::Trialing => Self::Trialing,
1384 SubscriptionStatus::Active => Self::Active,
1385 SubscriptionStatus::PastDue => Self::PastDue,
1386 SubscriptionStatus::Canceled => Self::Canceled,
1387 SubscriptionStatus::Unpaid => Self::Unpaid,
1388 SubscriptionStatus::Paused => Self::Paused,
1389 }
1390 }
1391}
1392
1393impl From<CancellationDetailsReason> for StripeCancellationReason {
1394 fn from(value: CancellationDetailsReason) -> Self {
1395 match value {
1396 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1397 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1398 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1399 }
1400 }
1401}
1402
1403/// Finds or creates a billing customer using the provided customer.
1404pub async fn find_or_create_billing_customer(
1405 app: &Arc<AppState>,
1406 stripe_client: &stripe::Client,
1407 customer_or_id: Expandable<Customer>,
1408) -> anyhow::Result<Option<billing_customer::Model>> {
1409 let customer_id = match &customer_or_id {
1410 Expandable::Id(id) => id,
1411 Expandable::Object(customer) => customer.id.as_ref(),
1412 };
1413
1414 // If we already have a billing customer record associated with the Stripe customer,
1415 // there's nothing more we need to do.
1416 if let Some(billing_customer) = app
1417 .db
1418 .get_billing_customer_by_stripe_customer_id(customer_id)
1419 .await?
1420 {
1421 return Ok(Some(billing_customer));
1422 }
1423
1424 // If all we have is a customer ID, resolve it to a full customer record by
1425 // hitting the Stripe API.
1426 let customer = match customer_or_id {
1427 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1428 Expandable::Object(customer) => *customer,
1429 };
1430
1431 let Some(email) = customer.email else {
1432 return Ok(None);
1433 };
1434
1435 let Some(user) = app.db.get_user_by_email(&email).await? else {
1436 return Ok(None);
1437 };
1438
1439 let billing_customer = app
1440 .db
1441 .create_billing_customer(&CreateBillingCustomerParams {
1442 user_id: user.id,
1443 stripe_customer_id: customer.id.to_string(),
1444 })
1445 .await?;
1446
1447 Ok(Some(billing_customer))
1448}
1449
/// How long the background sync loop sleeps between two runs of syncing LLM
/// request usage to Stripe.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1451
1452pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1453 let Some(stripe_billing) = app.stripe_billing.clone() else {
1454 log::warn!("failed to retrieve Stripe billing object");
1455 return;
1456 };
1457 let Some(llm_db) = app.llm_db.clone() else {
1458 log::warn!("failed to retrieve LLM database");
1459 return;
1460 };
1461
1462 let executor = app.executor.clone();
1463 executor.spawn_detached({
1464 let executor = executor.clone();
1465 async move {
1466 loop {
1467 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1468 .await
1469 .context("failed to sync LLM request usage to Stripe")
1470 .trace_err();
1471 executor
1472 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1473 .await;
1474 }
1475 }
1476 });
1477}
1478
1479async fn sync_model_request_usage_with_stripe(
1480 app: &Arc<AppState>,
1481 llm_db: &Arc<LlmDatabase>,
1482 stripe_billing: &Arc<StripeBilling>,
1483) -> anyhow::Result<()> {
1484 let staff_users = app.db.get_staff_users().await?;
1485 let staff_user_ids = staff_users
1486 .iter()
1487 .map(|user| user.id)
1488 .collect::<HashSet<UserId>>();
1489
1490 let usage_meters = llm_db
1491 .get_current_subscription_usage_meters(Utc::now())
1492 .await?;
1493 let usage_meters = usage_meters
1494 .into_iter()
1495 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1496 .collect::<Vec<_>>();
1497 let user_ids = usage_meters
1498 .iter()
1499 .map(|(_, usage)| usage.user_id)
1500 .collect::<HashSet<UserId>>();
1501 let billing_subscriptions = app
1502 .db
1503 .get_active_zed_pro_billing_subscriptions(user_ids)
1504 .await?;
1505
1506 let claude_sonnet_4 = stripe_billing
1507 .find_price_by_lookup_key("claude-sonnet-4-requests")
1508 .await?;
1509 let claude_sonnet_4_max = stripe_billing
1510 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1511 .await?;
1512 let claude_opus_4 = stripe_billing
1513 .find_price_by_lookup_key("claude-opus-4-requests")
1514 .await?;
1515 let claude_opus_4_max = stripe_billing
1516 .find_price_by_lookup_key("claude-opus-4-requests-max")
1517 .await?;
1518 let claude_3_5_sonnet = stripe_billing
1519 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1520 .await?;
1521 let claude_3_7_sonnet = stripe_billing
1522 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1523 .await?;
1524 let claude_3_7_sonnet_max = stripe_billing
1525 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1526 .await?;
1527
1528 for (usage_meter, usage) in usage_meters {
1529 maybe!(async {
1530 let Some((billing_customer, billing_subscription)) =
1531 billing_subscriptions.get(&usage.user_id)
1532 else {
1533 bail!(
1534 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1535 usage.user_id
1536 );
1537 };
1538
1539 let stripe_customer_id =
1540 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1541 let stripe_subscription_id =
1542 StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1543
1544 let model = llm_db.model_by_id(usage_meter.model_id)?;
1545
1546 let (price, meter_event_name) = match model.name.as_str() {
1547 "claude-opus-4" => match usage_meter.mode {
1548 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1549 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1550 },
1551 "claude-sonnet-4" => match usage_meter.mode {
1552 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1553 CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
1554 },
1555 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1556 "claude-3-7-sonnet" => match usage_meter.mode {
1557 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1558 CompletionMode::Max => {
1559 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1560 }
1561 },
1562 model_name => {
1563 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1564 }
1565 };
1566
1567 stripe_billing
1568 .subscribe_to_price(&stripe_subscription_id, price)
1569 .await?;
1570 stripe_billing
1571 .bill_model_request_usage(
1572 &stripe_customer_id,
1573 meter_event_name,
1574 usage_meter.requests,
1575 )
1576 .await?;
1577
1578 Ok(())
1579 })
1580 .await
1581 .log_err();
1582 }
1583
1584 Ok(())
1585}