1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
21 PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
31use crate::rpc::{ResultExt as _, Server};
32use crate::stripe_client::{
33 StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
34 StripeSubscriptionId, UpdateCustomerParams,
35};
36use crate::{AppState, Error, Result};
37use crate::{db::UserId, llm::db::LlmDatabase};
38use crate::{
39 db::{
40 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
41 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
42 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
43 },
44 stripe_billing::StripeBilling,
45};
46
47pub fn router() -> Router {
48 Router::new()
49 .route(
50 "/billing/preferences",
51 get(get_billing_preferences).put(update_billing_preferences),
52 )
53 .route(
54 "/billing/subscriptions",
55 get(list_billing_subscriptions).post(create_billing_subscription),
56 )
57 .route(
58 "/billing/subscriptions/manage",
59 post(manage_billing_subscription),
60 )
61 .route(
62 "/billing/subscriptions/sync",
63 post(sync_billing_subscription),
64 )
65 .route("/billing/usage", get(get_current_usage))
66}
67
/// Query-string parameters for `GET /billing/preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID used to look up the user whose preferences are returned.
    github_user_id: i32,
}
72
/// JSON response describing a user's billing preferences.
///
/// Returned by both `get_billing_preferences` and `update_billing_preferences`.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the billing customer's trial started, formatted as an RFC 3339
    /// string with millisecond precision. `None` when the user has no billing
    /// customer or no recorded trial start.
    trial_started_at: Option<String>,
    /// Maximum monthly LLM usage spend, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled.
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
80
81async fn get_billing_preferences(
82 Extension(app): Extension<Arc<AppState>>,
83 Query(params): Query<GetBillingPreferencesParams>,
84) -> Result<Json<BillingPreferencesResponse>> {
85 let user = app
86 .db
87 .get_user_by_github_user_id(params.github_user_id)
88 .await?
89 .context("user not found")?;
90
91 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
92 let preferences = app.db.get_billing_preferences(user.id).await?;
93
94 Ok(Json(BillingPreferencesResponse {
95 trial_started_at: billing_customer
96 .and_then(|billing_customer| billing_customer.trial_started_at)
97 .map(|trial_started_at| {
98 trial_started_at
99 .and_utc()
100 .to_rfc3339_opts(SecondsFormat::Millis, true)
101 }),
102 max_monthly_llm_usage_spending_in_cents: preferences
103 .as_ref()
104 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
105 preferences.max_monthly_llm_usage_spending_in_cents
106 }),
107 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
108 preferences.model_request_overages_enabled
109 }),
110 model_request_overages_spend_limit_in_cents: preferences
111 .as_ref()
112 .map_or(0, |preferences| {
113 preferences.model_request_overages_spend_limit_in_cents
114 }),
115 }))
116}
117
/// JSON body for `PUT /billing/preferences`.
///
/// Every preference field is marked `#[serde(default)]`, so an omitted field
/// deserializes to `0` / `false` instead of causing a deserialization error.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID used to look up the user whose preferences are updated.
    github_user_id: i32,
    /// Requested maximum monthly LLM spend in cents. The handler clamps
    /// negative values to `0`.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Requested overage spend limit in cents. The handler clamps negative
    /// values to `0`.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
128
129async fn update_billing_preferences(
130 Extension(app): Extension<Arc<AppState>>,
131 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
132 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
133) -> Result<Json<BillingPreferencesResponse>> {
134 let user = app
135 .db
136 .get_user_by_github_user_id(body.github_user_id)
137 .await?
138 .context("user not found")?;
139
140 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
141
142 let max_monthly_llm_usage_spending_in_cents =
143 body.max_monthly_llm_usage_spending_in_cents.max(0);
144 let model_request_overages_spend_limit_in_cents =
145 body.model_request_overages_spend_limit_in_cents.max(0);
146
147 let billing_preferences =
148 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
149 app.db
150 .update_billing_preferences(
151 user.id,
152 &UpdateBillingPreferencesParams {
153 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
154 max_monthly_llm_usage_spending_in_cents,
155 ),
156 model_request_overages_enabled: ActiveValue::set(
157 body.model_request_overages_enabled,
158 ),
159 model_request_overages_spend_limit_in_cents: ActiveValue::set(
160 model_request_overages_spend_limit_in_cents,
161 ),
162 },
163 )
164 .await?
165 } else {
166 app.db
167 .create_billing_preferences(
168 user.id,
169 &crate::db::CreateBillingPreferencesParams {
170 max_monthly_llm_usage_spending_in_cents,
171 model_request_overages_enabled: body.model_request_overages_enabled,
172 model_request_overages_spend_limit_in_cents,
173 },
174 )
175 .await?
176 };
177
178 SnowflakeRow::new(
179 "Billing Preferences Updated",
180 Some(user.metrics_id),
181 user.admin,
182 None,
183 json!({
184 "user_id": user.id,
185 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
186 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
187 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
188 }),
189 )
190 .write(&app.kinesis_client, &app.config.kinesis_stream)
191 .await
192 .log_err();
193
194 rpc_server.refresh_llm_tokens_for_user(user.id).await;
195
196 Ok(Json(BillingPreferencesResponse {
197 trial_started_at: billing_customer
198 .and_then(|billing_customer| billing_customer.trial_started_at)
199 .map(|trial_started_at| {
200 trial_started_at
201 .and_utc()
202 .to_rfc3339_opts(SecondsFormat::Millis, true)
203 }),
204 max_monthly_llm_usage_spending_in_cents: billing_preferences
205 .max_monthly_llm_usage_spending_in_cents,
206 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
207 model_request_overages_spend_limit_in_cents: billing_preferences
208 .model_request_overages_spend_limit_in_cents,
209 }))
210}
211
/// Query-string parameters for `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}
216
/// JSON representation of a single billing subscription, as returned by
/// `list_billing_subscriptions`.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// ID of the subscription record.
    id: BillingSubscriptionId,
    /// Display name derived from the subscription kind (for example
    /// `"Zed Pro"`); subscriptions without a kind are named `"Zed LLM Usage"`.
    name: String,
    /// The subscription's Stripe status.
    status: StripeSubscriptionStatus,
    /// The current billing period, present only when both its start and end
    /// are available.
    period: Option<BillingSubscriptionPeriodJson>,
    /// RFC 3339 timestamp of the current period end; only populated for
    /// Zed Pro trial subscriptions.
    trial_end_at: Option<String>,
    /// RFC 3339 timestamp at which the subscription is scheduled to cancel,
    /// if any.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}
228
/// The current billing period of a subscription.
///
/// Both bounds are RFC 3339 strings with millisecond precision.
#[derive(Debug, Serialize)]
struct BillingSubscriptionPeriodJson {
    /// Start of the current period.
    start_at: String,
    /// End of the current period.
    end_at: String,
}
234
/// JSON response for `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The user's billing subscriptions.
    subscriptions: Vec<BillingSubscriptionJson>,
}
239
240async fn list_billing_subscriptions(
241 Extension(app): Extension<Arc<AppState>>,
242 Query(params): Query<ListBillingSubscriptionsParams>,
243) -> Result<Json<ListBillingSubscriptionsResponse>> {
244 let user = app
245 .db
246 .get_user_by_github_user_id(params.github_user_id)
247 .await?
248 .context("user not found")?;
249
250 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
251
252 Ok(Json(ListBillingSubscriptionsResponse {
253 subscriptions: subscriptions
254 .into_iter()
255 .map(|subscription| BillingSubscriptionJson {
256 id: subscription.id,
257 name: match subscription.kind {
258 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
259 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
260 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
261 None => "Zed LLM Usage".to_string(),
262 },
263 status: subscription.stripe_subscription_status,
264 period: maybe!({
265 let start_at = subscription.current_period_start_at()?;
266 let end_at = subscription.current_period_end_at()?;
267
268 Some(BillingSubscriptionPeriodJson {
269 start_at: start_at.to_rfc3339_opts(SecondsFormat::Millis, true),
270 end_at: end_at.to_rfc3339_opts(SecondsFormat::Millis, true),
271 })
272 }),
273 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
274 maybe!({
275 let end_at = subscription.stripe_current_period_end?;
276 let end_at = DateTime::from_timestamp(end_at, 0)?;
277
278 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
279 })
280 } else {
281 None
282 },
283 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
284 cancel_at
285 .and_utc()
286 .to_rfc3339_opts(SecondsFormat::Millis, true)
287 }),
288 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
289 && subscription.stripe_subscription_status.is_cancelable()
290 && subscription.stripe_cancel_at.is_none(),
291 })
292 .collect(),
293 }))
294}
295
/// The product a caller can request a checkout session for.
///
/// Deserialized from its snake_case name (`zed_pro`, `zed_pro_trial`).
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// The Zed Pro subscription.
    ZedPro,
    /// The trial of Zed Pro.
    ZedProTrial,
}
302
/// JSON body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the user starting the checkout.
    github_user_id: i32,
    /// The product to check out.
    product: ProductCode,
}
308
/// JSON response for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the checkout session created for the requested product.
    checkout_session_url: String,
}
313
314/// Initiates a Stripe Checkout session for creating a billing subscription.
315async fn create_billing_subscription(
316 Extension(app): Extension<Arc<AppState>>,
317 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
318) -> Result<Json<CreateBillingSubscriptionResponse>> {
319 let user = app
320 .db
321 .get_user_by_github_user_id(body.github_user_id)
322 .await?
323 .context("user not found")?;
324
325 let Some(stripe_billing) = app.stripe_billing.clone() else {
326 log::error!("failed to retrieve Stripe billing object");
327 Err(Error::http(
328 StatusCode::NOT_IMPLEMENTED,
329 "not supported".into(),
330 ))?
331 };
332
333 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
334 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
335 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
336
337 if !is_checkout_allowed {
338 return Err(Error::http(
339 StatusCode::CONFLICT,
340 "user already has an active subscription".into(),
341 ));
342 }
343 }
344
345 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
346 if let Some(existing_billing_customer) = &existing_billing_customer {
347 if existing_billing_customer.has_overdue_invoices {
348 return Err(Error::http(
349 StatusCode::PAYMENT_REQUIRED,
350 "user has overdue invoices".into(),
351 ));
352 }
353 }
354
355 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
356 let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
357 if let Some(email) = user.email_address.as_deref() {
358 stripe_billing
359 .client()
360 .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
361 .await
362 // Update of email address is best-effort - continue checkout even if it fails
363 .context("error updating stripe customer email address")
364 .log_err();
365 }
366 customer_id
367 } else {
368 stripe_billing
369 .find_or_create_customer_by_email(user.email_address.as_deref())
370 .await?
371 };
372
373 let success_url = format!(
374 "{}/account?checkout_complete=1",
375 app.config.zed_dot_dev_url()
376 );
377
378 let checkout_session_url = match body.product {
379 ProductCode::ZedPro => {
380 stripe_billing
381 .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
382 .await?
383 }
384 ProductCode::ZedProTrial => {
385 if let Some(existing_billing_customer) = &existing_billing_customer {
386 if existing_billing_customer.trial_started_at.is_some() {
387 return Err(Error::http(
388 StatusCode::FORBIDDEN,
389 "user already used free trial".into(),
390 ));
391 }
392 }
393
394 let feature_flags = app.db.get_user_flags(user.id).await?;
395
396 stripe_billing
397 .checkout_with_zed_pro_trial(
398 &customer_id,
399 &user.github_login,
400 feature_flags,
401 &success_url,
402 )
403 .await?
404 }
405 };
406
407 Ok(Json(CreateBillingSubscriptionResponse {
408 checkout_session_url,
409 }))
410}
411
/// The action a caller wants to perform through `manage_billing_subscription`.
///
/// Deserialized from its snake_case name (for example `update_payment_method`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}
428
/// JSON body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user managing the subscription.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Path appended to the zed.dev base URL to build the post-completion
    /// redirect. Only read for the `UpdatePaymentMethod` intent, where it
    /// defaults to `/account`.
    redirect_to: Option<String>,
}
437
/// JSON response for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the created billing portal session. `None` when the request was
    /// fulfilled without opening the portal (stopping a cancellation, or
    /// ending a Zed Pro trial early).
    billing_portal_session_url: Option<String>,
}
442
443/// Initiates a Stripe customer portal session for managing a billing subscription.
444async fn manage_billing_subscription(
445 Extension(app): Extension<Arc<AppState>>,
446 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
447) -> Result<Json<ManageBillingSubscriptionResponse>> {
448 let user = app
449 .db
450 .get_user_by_github_user_id(body.github_user_id)
451 .await?
452 .context("user not found")?;
453
454 let Some(stripe_client) = app.real_stripe_client.clone() else {
455 log::error!("failed to retrieve Stripe client");
456 Err(Error::http(
457 StatusCode::NOT_IMPLEMENTED,
458 "not supported".into(),
459 ))?
460 };
461
462 let Some(stripe_billing) = app.stripe_billing.clone() else {
463 log::error!("failed to retrieve Stripe billing object");
464 Err(Error::http(
465 StatusCode::NOT_IMPLEMENTED,
466 "not supported".into(),
467 ))?
468 };
469
470 let customer = app
471 .db
472 .get_billing_customer_by_user_id(user.id)
473 .await?
474 .context("billing customer not found")?;
475 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
476 .context("failed to parse customer ID")?;
477
478 let subscription = app
479 .db
480 .get_billing_subscription_by_id(body.subscription_id)
481 .await?
482 .context("subscription not found")?;
483 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
484 .context("failed to parse subscription ID")?;
485
486 if body.intent == ManageSubscriptionIntent::StopCancellation {
487 let updated_stripe_subscription = Subscription::update(
488 &stripe_client,
489 &subscription_id,
490 stripe::UpdateSubscription {
491 cancel_at_period_end: Some(false),
492 ..Default::default()
493 },
494 )
495 .await?;
496
497 app.db
498 .update_billing_subscription(
499 subscription.id,
500 &UpdateBillingSubscriptionParams {
501 stripe_cancel_at: ActiveValue::set(
502 updated_stripe_subscription
503 .cancel_at
504 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
505 .map(|time| time.naive_utc()),
506 ),
507 ..Default::default()
508 },
509 )
510 .await?;
511
512 return Ok(Json(ManageBillingSubscriptionResponse {
513 billing_portal_session_url: None,
514 }));
515 }
516
517 let flow = match body.intent {
518 ManageSubscriptionIntent::ManageSubscription => None,
519 ManageSubscriptionIntent::UpgradeToPro => {
520 let zed_pro_price_id: stripe::PriceId =
521 stripe_billing.zed_pro_price_id().await?.try_into()?;
522 let zed_free_price_id: stripe::PriceId =
523 stripe_billing.zed_free_price_id().await?.try_into()?;
524
525 let stripe_subscription =
526 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
527
528 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
529 && stripe_subscription.items.data.iter().any(|item| {
530 item.price
531 .as_ref()
532 .map_or(false, |price| price.id == zed_pro_price_id)
533 });
534 if is_on_zed_pro_trial {
535 let payment_methods = PaymentMethod::list(
536 &stripe_client,
537 &stripe::ListPaymentMethods {
538 customer: Some(stripe_subscription.customer.id()),
539 ..Default::default()
540 },
541 )
542 .await?;
543
544 let has_payment_method = !payment_methods.data.is_empty();
545 if !has_payment_method {
546 return Err(Error::http(
547 StatusCode::BAD_REQUEST,
548 "missing payment method".into(),
549 ));
550 }
551
552 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
553 Subscription::update(
554 &stripe_client,
555 &stripe_subscription.id,
556 stripe::UpdateSubscription {
557 trial_end: Some(stripe::Scheduled::now()),
558 ..Default::default()
559 },
560 )
561 .await?;
562
563 return Ok(Json(ManageBillingSubscriptionResponse {
564 billing_portal_session_url: None,
565 }));
566 }
567
568 let subscription_item_to_update = stripe_subscription
569 .items
570 .data
571 .iter()
572 .find_map(|item| {
573 let price = item.price.as_ref()?;
574
575 if price.id == zed_free_price_id {
576 Some(item.id.clone())
577 } else {
578 None
579 }
580 })
581 .context("No subscription item to update")?;
582
583 Some(CreateBillingPortalSessionFlowData {
584 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
585 subscription_update_confirm: Some(
586 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
587 subscription: subscription.stripe_subscription_id,
588 items: vec![
589 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
590 id: subscription_item_to_update.to_string(),
591 price: Some(zed_pro_price_id.to_string()),
592 quantity: Some(1),
593 },
594 ],
595 discounts: None,
596 },
597 ),
598 ..Default::default()
599 })
600 }
601 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
602 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
603 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
604 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
605 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
606 return_url: format!(
607 "{}{path}",
608 app.config.zed_dot_dev_url(),
609 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
610 ),
611 }),
612 ..Default::default()
613 }),
614 ..Default::default()
615 }),
616 ManageSubscriptionIntent::Cancel => {
617 if subscription.kind == Some(SubscriptionKind::ZedFree) {
618 return Err(Error::http(
619 StatusCode::BAD_REQUEST,
620 "free subscription cannot be canceled".into(),
621 ));
622 }
623
624 Some(CreateBillingPortalSessionFlowData {
625 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
626 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
627 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
628 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
629 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
630 }),
631 ..Default::default()
632 }),
633 subscription_cancel: Some(
634 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
635 subscription: subscription.stripe_subscription_id,
636 retention: None,
637 },
638 ),
639 ..Default::default()
640 })
641 }
642 ManageSubscriptionIntent::StopCancellation => unreachable!(),
643 };
644
645 let mut params = CreateBillingPortalSession::new(customer_id);
646 params.flow_data = flow;
647 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
648 params.return_url = Some(&return_url);
649
650 let session = BillingPortalSession::create(&stripe_client, params).await?;
651
652 Ok(Json(ManageBillingSubscriptionResponse {
653 billing_portal_session_url: Some(session.url),
654 }))
655}
656
/// JSON body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID used to look up the user whose subscriptions are synced.
    github_user_id: i32,
}
661
/// JSON response for `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID stored on the user's billing customer record.
    stripe_customer_id: String,
}
666
667async fn sync_billing_subscription(
668 Extension(app): Extension<Arc<AppState>>,
669 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
670) -> Result<Json<SyncBillingSubscriptionResponse>> {
671 let Some(stripe_client) = app.stripe_client.clone() else {
672 log::error!("failed to retrieve Stripe client");
673 Err(Error::http(
674 StatusCode::NOT_IMPLEMENTED,
675 "not supported".into(),
676 ))?
677 };
678
679 let user = app
680 .db
681 .get_user_by_github_user_id(body.github_user_id)
682 .await?
683 .context("user not found")?;
684
685 let billing_customer = app
686 .db
687 .get_billing_customer_by_user_id(user.id)
688 .await?
689 .context("billing customer not found")?;
690 let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
691
692 let subscriptions = stripe_client
693 .list_subscriptions_for_customer(&stripe_customer_id)
694 .await?;
695
696 for subscription in subscriptions {
697 let subscription_id = subscription.id.clone();
698
699 sync_subscription(&app, &stripe_client, subscription)
700 .await
701 .with_context(|| {
702 format!(
703 "failed to sync subscription {subscription_id} for user {}",
704 user.id,
705 )
706 })?;
707 }
708
709 Ok(Json(SyncBillingSubscriptionResponse {
710 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
711 }))
712}
713
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
///
/// We currently wait 5 seconds; `poll_stripe_events_periodically` sleeps for
/// this long after each poll completes.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
726
/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// > Limit can range between 1 and 100, and the default is 10.
///
/// Used as the `limit` of the event listing request in `poll_stripe_events`.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
733
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
///
/// The count is cumulative over a single poll; the pages do not need to be
/// consecutive.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
740
741/// Polls the Stripe events API periodically to reconcile the records in our
742/// database with the data in Stripe.
743pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
744 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
745 log::warn!("failed to retrieve Stripe client");
746 return;
747 };
748 let Some(stripe_client) = app.stripe_client.clone() else {
749 log::warn!("failed to retrieve Stripe client");
750 return;
751 };
752
753 let executor = app.executor.clone();
754 executor.spawn_detached({
755 let executor = executor.clone();
756 async move {
757 loop {
758 poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
759 .await
760 .log_err();
761
762 executor.sleep(POLL_EVENTS_INTERVAL).await;
763 }
764 }
765 });
766}
767
768async fn poll_stripe_events(
769 app: &Arc<AppState>,
770 rpc_server: &Arc<Server>,
771 stripe_client: &Arc<dyn StripeClient>,
772 real_stripe_client: &stripe::Client,
773) -> anyhow::Result<()> {
774 fn event_type_to_string(event_type: EventType) -> String {
775 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
776 // so we need to unquote it.
777 event_type.to_string().trim_matches('"').to_string()
778 }
779
780 let event_types = [
781 EventType::CustomerCreated,
782 EventType::CustomerUpdated,
783 EventType::CustomerSubscriptionCreated,
784 EventType::CustomerSubscriptionUpdated,
785 EventType::CustomerSubscriptionPaused,
786 EventType::CustomerSubscriptionResumed,
787 EventType::CustomerSubscriptionDeleted,
788 ]
789 .into_iter()
790 .map(event_type_to_string)
791 .collect::<Vec<_>>();
792
793 let mut pages_of_already_processed_events = 0;
794 let mut unprocessed_events = Vec::new();
795
796 log::info!(
797 "Stripe events: starting retrieval for {}",
798 event_types.join(", ")
799 );
800 let mut params = ListEvents::new();
801 params.types = Some(event_types.clone());
802 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
803
804 let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
805 .await?
806 .paginate(params);
807
808 loop {
809 let processed_event_ids = {
810 let event_ids = event_pages
811 .page
812 .data
813 .iter()
814 .map(|event| event.id.as_str())
815 .collect::<Vec<_>>();
816 app.db
817 .get_processed_stripe_events_by_event_ids(&event_ids)
818 .await?
819 .into_iter()
820 .map(|event| event.stripe_event_id)
821 .collect::<Vec<_>>()
822 };
823
824 let mut processed_events_in_page = 0;
825 let events_in_page = event_pages.page.data.len();
826 for event in &event_pages.page.data {
827 if processed_event_ids.contains(&event.id.to_string()) {
828 processed_events_in_page += 1;
829 log::debug!("Stripe events: already processed '{}', skipping", event.id);
830 } else {
831 unprocessed_events.push(event.clone());
832 }
833 }
834
835 if processed_events_in_page == events_in_page {
836 pages_of_already_processed_events += 1;
837 }
838
839 if event_pages.page.has_more {
840 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
841 {
842 log::info!(
843 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
844 );
845 break;
846 } else {
847 log::info!("Stripe events: retrieving next page");
848 event_pages = event_pages.next(&real_stripe_client).await?;
849 }
850 } else {
851 break;
852 }
853 }
854
855 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
856
857 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
858 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
859
860 for event in unprocessed_events {
861 let event_id = event.id.clone();
862 let processed_event_params = CreateProcessedStripeEventParams {
863 stripe_event_id: event.id.to_string(),
864 stripe_event_type: event_type_to_string(event.type_),
865 stripe_event_created_timestamp: event.created,
866 };
867
868 // If the event has happened too far in the past, we don't want to
869 // process it and risk overwriting other more-recent updates.
870 //
871 // 1 day was chosen arbitrarily. This could be made longer or shorter.
872 let one_day = Duration::from_secs(24 * 60 * 60);
873 let a_day_ago = Utc::now() - one_day;
874 if a_day_ago.timestamp() > event.created {
875 log::info!(
876 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
877 event_id
878 );
879 app.db
880 .create_processed_stripe_event(&processed_event_params)
881 .await?;
882
883 continue;
884 }
885
886 let process_result = match event.type_ {
887 EventType::CustomerCreated | EventType::CustomerUpdated => {
888 handle_customer_event(app, real_stripe_client, event).await
889 }
890 EventType::CustomerSubscriptionCreated
891 | EventType::CustomerSubscriptionUpdated
892 | EventType::CustomerSubscriptionPaused
893 | EventType::CustomerSubscriptionResumed
894 | EventType::CustomerSubscriptionDeleted => {
895 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
896 }
897 _ => Ok(()),
898 };
899
900 if let Some(()) = process_result
901 .with_context(|| format!("failed to process event {event_id} successfully"))
902 .log_err()
903 {
904 app.db
905 .create_processed_stripe_event(&processed_event_params)
906 .await?;
907 }
908 }
909
910 Ok(())
911}
912
913async fn handle_customer_event(
914 app: &Arc<AppState>,
915 _stripe_client: &stripe::Client,
916 event: stripe::Event,
917) -> anyhow::Result<()> {
918 let EventObject::Customer(customer) = event.data.object else {
919 bail!("unexpected event payload for {}", event.id);
920 };
921
922 log::info!("handling Stripe {} event: {}", event.type_, event.id);
923
924 let Some(email) = customer.email else {
925 log::info!("Stripe customer has no email: skipping");
926 return Ok(());
927 };
928
929 let Some(user) = app.db.get_user_by_email(&email).await? else {
930 log::info!("no user found for email: skipping");
931 return Ok(());
932 };
933
934 if let Some(existing_customer) = app
935 .db
936 .get_billing_customer_by_stripe_customer_id(&customer.id)
937 .await?
938 {
939 app.db
940 .update_billing_customer(
941 existing_customer.id,
942 &UpdateBillingCustomerParams {
943 // For now we just leave the information as-is, as it is not
944 // likely to change.
945 ..Default::default()
946 },
947 )
948 .await?;
949 } else {
950 app.db
951 .create_billing_customer(&CreateBillingCustomerParams {
952 user_id: user.id,
953 stripe_customer_id: customer.id.to_string(),
954 })
955 .await?;
956 }
957
958 Ok(())
959}
960
961async fn sync_subscription(
962 app: &Arc<AppState>,
963 stripe_client: &Arc<dyn StripeClient>,
964 subscription: StripeSubscription,
965) -> anyhow::Result<billing_customer::Model> {
966 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
967 stripe_billing
968 .determine_subscription_kind(&subscription)
969 .await
970 } else {
971 None
972 };
973
974 let billing_customer =
975 find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
976 .await?
977 .context("billing customer not found")?;
978
979 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
980 if subscription.status == SubscriptionStatus::Trialing {
981 let current_period_start =
982 DateTime::from_timestamp(subscription.current_period_start, 0)
983 .context("No trial subscription period start")?;
984
985 app.db
986 .update_billing_customer(
987 billing_customer.id,
988 &UpdateBillingCustomerParams {
989 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
990 ..Default::default()
991 },
992 )
993 .await?;
994 }
995 }
996
997 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
998 && subscription
999 .cancellation_details
1000 .as_ref()
1001 .and_then(|details| details.reason)
1002 .map_or(false, |reason| {
1003 reason == StripeCancellationDetailsReason::PaymentFailed
1004 });
1005
1006 if was_canceled_due_to_payment_failure {
1007 app.db
1008 .update_billing_customer(
1009 billing_customer.id,
1010 &UpdateBillingCustomerParams {
1011 has_overdue_invoices: ActiveValue::set(true),
1012 ..Default::default()
1013 },
1014 )
1015 .await?;
1016 }
1017
1018 if let Some(existing_subscription) = app
1019 .db
1020 .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
1021 .await?
1022 {
1023 app.db
1024 .update_billing_subscription(
1025 existing_subscription.id,
1026 &UpdateBillingSubscriptionParams {
1027 billing_customer_id: ActiveValue::set(billing_customer.id),
1028 kind: ActiveValue::set(subscription_kind),
1029 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1030 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1031 stripe_cancel_at: ActiveValue::set(
1032 subscription
1033 .cancel_at
1034 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1035 .map(|time| time.naive_utc()),
1036 ),
1037 stripe_cancellation_reason: ActiveValue::set(
1038 subscription
1039 .cancellation_details
1040 .and_then(|details| details.reason)
1041 .map(|reason| reason.into()),
1042 ),
1043 stripe_current_period_start: ActiveValue::set(Some(
1044 subscription.current_period_start,
1045 )),
1046 stripe_current_period_end: ActiveValue::set(Some(
1047 subscription.current_period_end,
1048 )),
1049 },
1050 )
1051 .await?;
1052 } else {
1053 if let Some(existing_subscription) = app
1054 .db
1055 .get_active_billing_subscription(billing_customer.user_id)
1056 .await?
1057 {
1058 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1059 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1060 {
1061 let stripe_subscription_id = StripeSubscriptionId(
1062 existing_subscription.stripe_subscription_id.clone().into(),
1063 );
1064
1065 stripe_client
1066 .cancel_subscription(&stripe_subscription_id)
1067 .await?;
1068 } else {
1069 // If the user already has an active billing subscription, ignore the
1070 // event and return an `Ok` to signal that it was processed
1071 // successfully.
1072 //
1073 // There is the possibility that this could cause us to not create a
1074 // subscription in the following scenario:
1075 //
1076 // 1. User has an active subscription A
1077 // 2. User cancels subscription A
1078 // 3. User creates a new subscription B
1079 // 4. We process the new subscription B before the cancellation of subscription A
1080 // 5. User ends up with no subscriptions
1081 //
1082 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1083
1084 log::info!(
1085 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1086 user_id = billing_customer.user_id,
1087 subscription_id = subscription.id
1088 );
1089 return Ok(billing_customer);
1090 }
1091 }
1092
1093 app.db
1094 .create_billing_subscription(&CreateBillingSubscriptionParams {
1095 billing_customer_id: billing_customer.id,
1096 kind: subscription_kind,
1097 stripe_subscription_id: subscription.id.to_string(),
1098 stripe_subscription_status: subscription.status.into(),
1099 stripe_cancellation_reason: subscription
1100 .cancellation_details
1101 .and_then(|details| details.reason)
1102 .map(|reason| reason.into()),
1103 stripe_current_period_start: Some(subscription.current_period_start),
1104 stripe_current_period_end: Some(subscription.current_period_end),
1105 })
1106 .await?;
1107 }
1108
1109 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1110 if subscription.status == SubscriptionStatus::Canceled
1111 || subscription.status == SubscriptionStatus::Paused
1112 {
1113 let already_has_active_billing_subscription = app
1114 .db
1115 .has_active_billing_subscription(billing_customer.user_id)
1116 .await?;
1117 if !already_has_active_billing_subscription {
1118 let stripe_customer_id =
1119 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1120
1121 stripe_billing
1122 .subscribe_to_zed_free(stripe_customer_id)
1123 .await?;
1124 }
1125 }
1126 }
1127
1128 Ok(billing_customer)
1129}
1130
1131async fn handle_customer_subscription_event(
1132 app: &Arc<AppState>,
1133 rpc_server: &Arc<Server>,
1134 stripe_client: &Arc<dyn StripeClient>,
1135 event: stripe::Event,
1136) -> anyhow::Result<()> {
1137 let EventObject::Subscription(subscription) = event.data.object else {
1138 bail!("unexpected event payload for {}", event.id);
1139 };
1140
1141 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1142
1143 let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
1144
1145 // When the user's subscription changes, push down any changes to their plan.
1146 rpc_server
1147 .update_plan_for_user(billing_customer.user_id)
1148 .await
1149 .trace_err();
1150
1151 // When the user's subscription changes, we want to refresh their LLM tokens
1152 // to either grant/revoke access.
1153 rpc_server
1154 .refresh_llm_tokens_for_user(billing_customer.user_id)
1155 .await;
1156
1157 Ok(())
1158}
1159
/// Query parameters accepted by [`get_current_usage`].
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID used to look up the user whose usage is requested.
    github_user_id: i32,
}
1164
/// Usage of a single metered resource, as reported by [`get_current_usage`].
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Amount consumed so far.
    pub used: i32,
    /// Maximum allowed by the plan, or `None` when the plan is unlimited.
    pub limit: Option<i32>,
    /// Amount still available (`limit - used`, never below zero), or `None`
    /// when the plan is unlimited.
    pub remaining: Option<i32>,
}
1171
/// Request count for one model/completion-mode pair, taken from a usage meter.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the requests were made against.
    pub model: String,
    /// Completion mode recorded on the usage meter.
    pub mode: CompletionMode,
    /// Number of requests recorded on the usage meter.
    pub requests: i32,
}
1178
/// Usage figures for a subscription period, as reported by [`get_current_usage`].
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Aggregate model request usage against the plan's limit.
    pub model_requests: UsageCounts,
    /// Per-model breakdown of requests, one entry per usage meter.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit prediction usage against the plan's limit.
    pub edit_predictions: UsageCounts,
}
1185
/// Response body of [`get_current_usage`].
///
/// The `Default` value (empty `plan`, no `current_usage`) is what the handler
/// returns when the user has no active subscription or the subscription has no
/// current period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// String form of the user's plan.
    pub plan: String,
    /// Usage for the current subscription period, when one is available.
    pub current_usage: Option<CurrentUsage>,
}
1191
1192async fn get_current_usage(
1193 Extension(app): Extension<Arc<AppState>>,
1194 Query(params): Query<GetCurrentUsageParams>,
1195) -> Result<Json<GetCurrentUsageResponse>> {
1196 let user = app
1197 .db
1198 .get_user_by_github_user_id(params.github_user_id)
1199 .await?
1200 .context("user not found")?;
1201
1202 let feature_flags = app.db.get_user_flags(user.id).await?;
1203 let has_extended_trial = feature_flags
1204 .iter()
1205 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1206
1207 let Some(llm_db) = app.llm_db.clone() else {
1208 return Err(Error::http(
1209 StatusCode::NOT_IMPLEMENTED,
1210 "LLM database not available".into(),
1211 ));
1212 };
1213
1214 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1215 return Ok(Json(GetCurrentUsageResponse::default()));
1216 };
1217
1218 let subscription_period = maybe!({
1219 let period_start_at = subscription.current_period_start_at()?;
1220 let period_end_at = subscription.current_period_end_at()?;
1221
1222 Some((period_start_at, period_end_at))
1223 });
1224
1225 let Some((period_start_at, period_end_at)) = subscription_period else {
1226 return Ok(Json(GetCurrentUsageResponse::default()));
1227 };
1228
1229 let usage = llm_db
1230 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1231 .await?;
1232
1233 let plan = subscription
1234 .kind
1235 .map(Into::into)
1236 .unwrap_or(zed_llm_client::Plan::ZedFree);
1237
1238 let model_requests_limit = match plan.model_requests_limit() {
1239 zed_llm_client::UsageLimit::Limited(limit) => {
1240 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1241 1_000
1242 } else {
1243 limit
1244 };
1245
1246 Some(limit)
1247 }
1248 zed_llm_client::UsageLimit::Unlimited => None,
1249 };
1250
1251 let edit_predictions_limit = match plan.edit_predictions_limit() {
1252 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1253 zed_llm_client::UsageLimit::Unlimited => None,
1254 };
1255
1256 let Some(usage) = usage else {
1257 return Ok(Json(GetCurrentUsageResponse {
1258 plan: plan.as_str().to_string(),
1259 current_usage: Some(CurrentUsage {
1260 model_requests: UsageCounts {
1261 used: 0,
1262 limit: model_requests_limit,
1263 remaining: model_requests_limit,
1264 },
1265 model_request_usage: Vec::new(),
1266 edit_predictions: UsageCounts {
1267 used: 0,
1268 limit: edit_predictions_limit,
1269 remaining: edit_predictions_limit,
1270 },
1271 }),
1272 }));
1273 };
1274
1275 let subscription_usage_meters = llm_db
1276 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1277 .await?;
1278
1279 let model_request_usage = subscription_usage_meters
1280 .into_iter()
1281 .filter_map(|(usage_meter, _usage)| {
1282 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1283
1284 Some(ModelRequestUsage {
1285 model: model.name.clone(),
1286 mode: usage_meter.mode,
1287 requests: usage_meter.requests,
1288 })
1289 })
1290 .collect::<Vec<_>>();
1291
1292 Ok(Json(GetCurrentUsageResponse {
1293 plan: plan.as_str().to_string(),
1294 current_usage: Some(CurrentUsage {
1295 model_requests: UsageCounts {
1296 used: usage.model_requests,
1297 limit: model_requests_limit,
1298 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1299 },
1300 model_request_usage,
1301 edit_predictions: UsageCounts {
1302 used: usage.edit_predictions,
1303 limit: edit_predictions_limit,
1304 remaining: edit_predictions_limit
1305 .map(|limit| (limit - usage.edit_predictions).max(0)),
1306 },
1307 }),
1308 }))
1309}
1310
1311impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1312 fn from(value: SubscriptionStatus) -> Self {
1313 match value {
1314 SubscriptionStatus::Incomplete => Self::Incomplete,
1315 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1316 SubscriptionStatus::Trialing => Self::Trialing,
1317 SubscriptionStatus::Active => Self::Active,
1318 SubscriptionStatus::PastDue => Self::PastDue,
1319 SubscriptionStatus::Canceled => Self::Canceled,
1320 SubscriptionStatus::Unpaid => Self::Unpaid,
1321 SubscriptionStatus::Paused => Self::Paused,
1322 }
1323 }
1324}
1325
1326impl From<CancellationDetailsReason> for StripeCancellationReason {
1327 fn from(value: CancellationDetailsReason) -> Self {
1328 match value {
1329 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1330 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1331 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1332 }
1333 }
1334}
1335
1336/// Finds or creates a billing customer using the provided customer.
1337pub async fn find_or_create_billing_customer(
1338 app: &Arc<AppState>,
1339 stripe_client: &dyn StripeClient,
1340 customer_id: &StripeCustomerId,
1341) -> anyhow::Result<Option<billing_customer::Model>> {
1342 // If we already have a billing customer record associated with the Stripe customer,
1343 // there's nothing more we need to do.
1344 if let Some(billing_customer) = app
1345 .db
1346 .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
1347 .await?
1348 {
1349 return Ok(Some(billing_customer));
1350 }
1351
1352 let customer = stripe_client.get_customer(customer_id).await?;
1353
1354 let Some(email) = customer.email else {
1355 return Ok(None);
1356 };
1357
1358 let Some(user) = app.db.get_user_by_email(&email).await? else {
1359 return Ok(None);
1360 };
1361
1362 let billing_customer = app
1363 .db
1364 .create_billing_customer(&CreateBillingCustomerParams {
1365 user_id: user.id,
1366 stripe_customer_id: customer.id.to_string(),
1367 })
1368 .await?;
1369
1370 Ok(Some(billing_customer))
1371}
1372
/// How long the loop in [`sync_llm_request_usage_with_stripe_periodically`]
/// sleeps between sync attempts (60 seconds).
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1374
1375pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1376 let Some(stripe_billing) = app.stripe_billing.clone() else {
1377 log::warn!("failed to retrieve Stripe billing object");
1378 return;
1379 };
1380 let Some(llm_db) = app.llm_db.clone() else {
1381 log::warn!("failed to retrieve LLM database");
1382 return;
1383 };
1384
1385 let executor = app.executor.clone();
1386 executor.spawn_detached({
1387 let executor = executor.clone();
1388 async move {
1389 loop {
1390 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1391 .await
1392 .context("failed to sync LLM request usage to Stripe")
1393 .trace_err();
1394 executor
1395 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1396 .await;
1397 }
1398 }
1399 });
1400}
1401
1402async fn sync_model_request_usage_with_stripe(
1403 app: &Arc<AppState>,
1404 llm_db: &Arc<LlmDatabase>,
1405 stripe_billing: &Arc<StripeBilling>,
1406) -> anyhow::Result<()> {
1407 let staff_users = app.db.get_staff_users().await?;
1408 let staff_user_ids = staff_users
1409 .iter()
1410 .map(|user| user.id)
1411 .collect::<HashSet<UserId>>();
1412
1413 let usage_meters = llm_db
1414 .get_current_subscription_usage_meters(Utc::now())
1415 .await?;
1416 let usage_meters = usage_meters
1417 .into_iter()
1418 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1419 .collect::<Vec<_>>();
1420 let user_ids = usage_meters
1421 .iter()
1422 .map(|(_, usage)| usage.user_id)
1423 .collect::<HashSet<UserId>>();
1424 let billing_subscriptions = app
1425 .db
1426 .get_active_zed_pro_billing_subscriptions(user_ids)
1427 .await?;
1428
1429 let claude_sonnet_4 = stripe_billing
1430 .find_price_by_lookup_key("claude-sonnet-4-requests")
1431 .await?;
1432 let claude_sonnet_4_max = stripe_billing
1433 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1434 .await?;
1435 let claude_opus_4 = stripe_billing
1436 .find_price_by_lookup_key("claude-opus-4-requests")
1437 .await?;
1438 let claude_opus_4_max = stripe_billing
1439 .find_price_by_lookup_key("claude-opus-4-requests-max")
1440 .await?;
1441 let claude_3_5_sonnet = stripe_billing
1442 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1443 .await?;
1444 let claude_3_7_sonnet = stripe_billing
1445 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1446 .await?;
1447 let claude_3_7_sonnet_max = stripe_billing
1448 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1449 .await?;
1450
1451 for (usage_meter, usage) in usage_meters {
1452 maybe!(async {
1453 let Some((billing_customer, billing_subscription)) =
1454 billing_subscriptions.get(&usage.user_id)
1455 else {
1456 bail!(
1457 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1458 usage.user_id
1459 );
1460 };
1461
1462 let stripe_customer_id =
1463 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1464 let stripe_subscription_id =
1465 StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1466
1467 let model = llm_db.model_by_id(usage_meter.model_id)?;
1468
1469 let (price, meter_event_name) = match model.name.as_str() {
1470 "claude-opus-4" => match usage_meter.mode {
1471 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1472 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1473 },
1474 "claude-sonnet-4" => match usage_meter.mode {
1475 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1476 CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
1477 },
1478 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1479 "claude-3-7-sonnet" => match usage_meter.mode {
1480 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1481 CompletionMode::Max => {
1482 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1483 }
1484 },
1485 model_name => {
1486 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1487 }
1488 };
1489
1490 stripe_billing
1491 .subscribe_to_price(&stripe_subscription_id, price)
1492 .await?;
1493 stripe_billing
1494 .bill_model_request_usage(
1495 &stripe_customer_id,
1496 meter_event_name,
1497 usage_meter.requests,
1498 )
1499 .await?;
1500
1501 Ok(())
1502 })
1503 .await
1504 .log_err();
1505 }
1506
1507 Ok(())
1508}