1use anyhow::{Context as _, bail};
2use axum::routing::put;
3use axum::{
4 Extension, Json, Router,
5 extract::{self, Query},
6 routing::{get, post},
7};
8use chrono::{DateTime, SecondsFormat, Utc};
9use collections::{HashMap, HashSet};
10use reqwest::StatusCode;
11use sea_orm::ActiveValue;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::{str::FromStr, sync::Arc, time::Duration};
15use stripe::{
16 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
17 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
18 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
20 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
21 CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
22 PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
23};
24use util::{ResultExt, maybe};
25use zed_llm_client::LanguageModelProvider;
26
27use crate::api::events::SnowflakeRow;
28use crate::db::billing_subscription::{
29 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
30};
31use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
32use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
33use crate::rpc::{ResultExt as _, Server};
34use crate::stripe_client::{
35 StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
36 StripeSubscriptionId, UpdateCustomerParams,
37};
38use crate::{AppState, Error, Result};
39use crate::{db::UserId, llm::db::LlmDatabase};
40use crate::{
41 db::{
42 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
43 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
44 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
45 },
46 stripe_billing::StripeBilling,
47};
48
49pub fn router() -> Router {
50 Router::new()
51 .route("/billing/preferences", put(update_billing_preferences))
52 .route(
53 "/billing/subscriptions",
54 get(list_billing_subscriptions).post(create_billing_subscription),
55 )
56 .route(
57 "/billing/subscriptions/manage",
58 post(manage_billing_subscription),
59 )
60 .route(
61 "/billing/subscriptions/sync",
62 post(sync_billing_subscription),
63 )
64 .route("/billing/usage", get(get_current_usage))
65}
66
/// Response body for `PUT /billing/preferences`.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, as an RFC 3339 timestamp with
    /// millisecond precision; `None` if the user has no billing customer or
    /// no recorded trial start.
    trial_started_at: Option<String>,
    /// Maximum monthly LLM usage spend, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled.
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}

/// Request body for `PUT /billing/preferences`.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` values
/// (`0` / `false`) when omitted from the request.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID of the user whose preferences are being updated.
    github_user_id: i32,
    /// Desired maximum monthly LLM usage spend, in cents. The handler clamps
    /// negative values to zero.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Desired spend limit for model request overages, in cents. The handler
    /// clamps negative values to zero.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
85
86async fn update_billing_preferences(
87 Extension(app): Extension<Arc<AppState>>,
88 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
89 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
90) -> Result<Json<BillingPreferencesResponse>> {
91 let user = app
92 .db
93 .get_user_by_github_user_id(body.github_user_id)
94 .await?
95 .context("user not found")?;
96
97 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
98
99 let max_monthly_llm_usage_spending_in_cents =
100 body.max_monthly_llm_usage_spending_in_cents.max(0);
101 let model_request_overages_spend_limit_in_cents =
102 body.model_request_overages_spend_limit_in_cents.max(0);
103
104 let billing_preferences =
105 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
106 app.db
107 .update_billing_preferences(
108 user.id,
109 &UpdateBillingPreferencesParams {
110 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
111 max_monthly_llm_usage_spending_in_cents,
112 ),
113 model_request_overages_enabled: ActiveValue::set(
114 body.model_request_overages_enabled,
115 ),
116 model_request_overages_spend_limit_in_cents: ActiveValue::set(
117 model_request_overages_spend_limit_in_cents,
118 ),
119 },
120 )
121 .await?
122 } else {
123 app.db
124 .create_billing_preferences(
125 user.id,
126 &crate::db::CreateBillingPreferencesParams {
127 max_monthly_llm_usage_spending_in_cents,
128 model_request_overages_enabled: body.model_request_overages_enabled,
129 model_request_overages_spend_limit_in_cents,
130 },
131 )
132 .await?
133 };
134
135 SnowflakeRow::new(
136 "Billing Preferences Updated",
137 Some(user.metrics_id),
138 user.admin,
139 None,
140 json!({
141 "user_id": user.id,
142 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
143 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
144 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
145 }),
146 )
147 .write(&app.kinesis_client, &app.config.kinesis_stream)
148 .await
149 .log_err();
150
151 rpc_server.refresh_llm_tokens_for_user(user.id).await;
152
153 Ok(Json(BillingPreferencesResponse {
154 trial_started_at: billing_customer
155 .and_then(|billing_customer| billing_customer.trial_started_at)
156 .map(|trial_started_at| {
157 trial_started_at
158 .and_utc()
159 .to_rfc3339_opts(SecondsFormat::Millis, true)
160 }),
161 max_monthly_llm_usage_spending_in_cents: billing_preferences
162 .max_monthly_llm_usage_spending_in_cents,
163 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
164 model_request_overages_spend_limit_in_cents: billing_preferences
165 .model_request_overages_spend_limit_in_cents,
166 }))
167}
168
/// Query parameters for `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID of the user whose subscriptions are listed.
    github_user_id: i32,
}

/// A single billing subscription as returned by `GET /billing/subscriptions`.
///
/// All timestamps are RFC 3339 strings with millisecond precision.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Our database ID for the subscription.
    id: BillingSubscriptionId,
    /// Human-readable name of the subscription (e.g. "Zed Pro").
    name: String,
    /// The subscription's status as recorded from Stripe.
    status: StripeSubscriptionStatus,
    /// The current billing period, if both its start and end are known.
    period: Option<BillingSubscriptionPeriodJson>,
    /// When the trial ends; only populated for Zed Pro trial subscriptions.
    trial_end_at: Option<String>,
    /// When the subscription is scheduled to be canceled, if it is.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}

/// The start and end of a subscription's current billing period.
#[derive(Debug, Serialize)]
struct BillingSubscriptionPeriodJson {
    /// Period start, as an RFC 3339 timestamp.
    start_at: String,
    /// Period end, as an RFC 3339 timestamp.
    end_at: String,
}

/// Response body for `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// Every billing subscription recorded for the user.
    subscriptions: Vec<BillingSubscriptionJson>,
}
196
197async fn list_billing_subscriptions(
198 Extension(app): Extension<Arc<AppState>>,
199 Query(params): Query<ListBillingSubscriptionsParams>,
200) -> Result<Json<ListBillingSubscriptionsResponse>> {
201 let user = app
202 .db
203 .get_user_by_github_user_id(params.github_user_id)
204 .await?
205 .context("user not found")?;
206
207 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
208
209 Ok(Json(ListBillingSubscriptionsResponse {
210 subscriptions: subscriptions
211 .into_iter()
212 .map(|subscription| BillingSubscriptionJson {
213 id: subscription.id,
214 name: match subscription.kind {
215 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
216 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
217 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
218 None => "Zed LLM Usage".to_string(),
219 },
220 status: subscription.stripe_subscription_status,
221 period: maybe!({
222 let start_at = subscription.current_period_start_at()?;
223 let end_at = subscription.current_period_end_at()?;
224
225 Some(BillingSubscriptionPeriodJson {
226 start_at: start_at.to_rfc3339_opts(SecondsFormat::Millis, true),
227 end_at: end_at.to_rfc3339_opts(SecondsFormat::Millis, true),
228 })
229 }),
230 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
231 maybe!({
232 let end_at = subscription.stripe_current_period_end?;
233 let end_at = DateTime::from_timestamp(end_at, 0)?;
234
235 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
236 })
237 } else {
238 None
239 },
240 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
241 cancel_at
242 .and_utc()
243 .to_rfc3339_opts(SecondsFormat::Millis, true)
244 }),
245 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
246 && subscription.stripe_subscription_status.is_cancelable()
247 && subscription.stripe_cancel_at.is_none(),
248 })
249 .collect(),
250 }))
251}
252
/// The product a user can check out via `POST /billing/subscriptions`.
///
/// Deserialized from snake_case strings (`"zed_pro"`, `"zed_pro_trial"`).
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// Zed Pro.
    ZedPro,
    /// A Zed Pro trial.
    ZedProTrial,
}

/// Request body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID of the user starting the checkout.
    github_user_id: i32,
    /// The product to check out.
    product: ProductCode,
}

/// Response body for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session the user should be sent to.
    checkout_session_url: String,
}
270
271/// Initiates a Stripe Checkout session for creating a billing subscription.
272async fn create_billing_subscription(
273 Extension(app): Extension<Arc<AppState>>,
274 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
275) -> Result<Json<CreateBillingSubscriptionResponse>> {
276 let user = app
277 .db
278 .get_user_by_github_user_id(body.github_user_id)
279 .await?
280 .context("user not found")?;
281
282 let Some(stripe_billing) = app.stripe_billing.clone() else {
283 log::error!("failed to retrieve Stripe billing object");
284 Err(Error::http(
285 StatusCode::NOT_IMPLEMENTED,
286 "not supported".into(),
287 ))?
288 };
289
290 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
291 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
292 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
293
294 if !is_checkout_allowed {
295 return Err(Error::http(
296 StatusCode::CONFLICT,
297 "user already has an active subscription".into(),
298 ));
299 }
300 }
301
302 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
303 if let Some(existing_billing_customer) = &existing_billing_customer {
304 if existing_billing_customer.has_overdue_invoices {
305 return Err(Error::http(
306 StatusCode::PAYMENT_REQUIRED,
307 "user has overdue invoices".into(),
308 ));
309 }
310 }
311
312 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
313 let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
314 if let Some(email) = user.email_address.as_deref() {
315 stripe_billing
316 .client()
317 .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
318 .await
319 // Update of email address is best-effort - continue checkout even if it fails
320 .context("error updating stripe customer email address")
321 .log_err();
322 }
323 customer_id
324 } else {
325 stripe_billing
326 .find_or_create_customer_by_email(user.email_address.as_deref())
327 .await?
328 };
329
330 let success_url = format!(
331 "{}/account?checkout_complete=1",
332 app.config.zed_dot_dev_url()
333 );
334
335 let checkout_session_url = match body.product {
336 ProductCode::ZedPro => {
337 stripe_billing
338 .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
339 .await?
340 }
341 ProductCode::ZedProTrial => {
342 if let Some(existing_billing_customer) = &existing_billing_customer {
343 if existing_billing_customer.trial_started_at.is_some() {
344 return Err(Error::http(
345 StatusCode::FORBIDDEN,
346 "user already used free trial".into(),
347 ));
348 }
349 }
350
351 let feature_flags = app.db.get_user_flags(user.id).await?;
352
353 stripe_billing
354 .checkout_with_zed_pro_trial(
355 &customer_id,
356 &user.github_login,
357 feature_flags,
358 &success_url,
359 )
360 .await?
361 }
362 };
363
364 Ok(Json(CreateBillingSubscriptionResponse {
365 checkout_session_url,
366 }))
367}
368
/// What the user wants to do when calling `POST /billing/subscriptions/manage`.
///
/// Deserialized from snake_case strings (e.g. `"update_payment_method"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}

/// Request body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID of the user managing their subscription.
    github_user_id: i32,
    /// What the user wants to do.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Optional path (e.g. `/account`) appended to the zed.dev URL to build the
    /// return URL of the `update_payment_method` flow. When absent, `/account`
    /// is used.
    redirect_to: Option<String>,
}

/// Response body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session to send the user to, or `None`
    /// when the requested change was applied directly without the portal.
    billing_portal_session_url: Option<String>,
}
399
400/// Initiates a Stripe customer portal session for managing a billing subscription.
401async fn manage_billing_subscription(
402 Extension(app): Extension<Arc<AppState>>,
403 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
404) -> Result<Json<ManageBillingSubscriptionResponse>> {
405 let user = app
406 .db
407 .get_user_by_github_user_id(body.github_user_id)
408 .await?
409 .context("user not found")?;
410
411 let Some(stripe_client) = app.real_stripe_client.clone() else {
412 log::error!("failed to retrieve Stripe client");
413 Err(Error::http(
414 StatusCode::NOT_IMPLEMENTED,
415 "not supported".into(),
416 ))?
417 };
418
419 let Some(stripe_billing) = app.stripe_billing.clone() else {
420 log::error!("failed to retrieve Stripe billing object");
421 Err(Error::http(
422 StatusCode::NOT_IMPLEMENTED,
423 "not supported".into(),
424 ))?
425 };
426
427 let customer = app
428 .db
429 .get_billing_customer_by_user_id(user.id)
430 .await?
431 .context("billing customer not found")?;
432 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
433 .context("failed to parse customer ID")?;
434
435 let subscription = app
436 .db
437 .get_billing_subscription_by_id(body.subscription_id)
438 .await?
439 .context("subscription not found")?;
440 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
441 .context("failed to parse subscription ID")?;
442
443 if body.intent == ManageSubscriptionIntent::StopCancellation {
444 let updated_stripe_subscription = Subscription::update(
445 &stripe_client,
446 &subscription_id,
447 stripe::UpdateSubscription {
448 cancel_at_period_end: Some(false),
449 ..Default::default()
450 },
451 )
452 .await?;
453
454 app.db
455 .update_billing_subscription(
456 subscription.id,
457 &UpdateBillingSubscriptionParams {
458 stripe_cancel_at: ActiveValue::set(
459 updated_stripe_subscription
460 .cancel_at
461 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
462 .map(|time| time.naive_utc()),
463 ),
464 ..Default::default()
465 },
466 )
467 .await?;
468
469 return Ok(Json(ManageBillingSubscriptionResponse {
470 billing_portal_session_url: None,
471 }));
472 }
473
474 let flow = match body.intent {
475 ManageSubscriptionIntent::ManageSubscription => None,
476 ManageSubscriptionIntent::UpgradeToPro => {
477 let zed_pro_price_id: stripe::PriceId =
478 stripe_billing.zed_pro_price_id().await?.try_into()?;
479 let zed_free_price_id: stripe::PriceId =
480 stripe_billing.zed_free_price_id().await?.try_into()?;
481
482 let stripe_subscription =
483 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
484
485 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
486 && stripe_subscription.items.data.iter().any(|item| {
487 item.price
488 .as_ref()
489 .map_or(false, |price| price.id == zed_pro_price_id)
490 });
491 if is_on_zed_pro_trial {
492 let payment_methods = PaymentMethod::list(
493 &stripe_client,
494 &stripe::ListPaymentMethods {
495 customer: Some(stripe_subscription.customer.id()),
496 ..Default::default()
497 },
498 )
499 .await?;
500
501 let has_payment_method = !payment_methods.data.is_empty();
502 if !has_payment_method {
503 return Err(Error::http(
504 StatusCode::BAD_REQUEST,
505 "missing payment method".into(),
506 ));
507 }
508
509 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
510 Subscription::update(
511 &stripe_client,
512 &stripe_subscription.id,
513 stripe::UpdateSubscription {
514 trial_end: Some(stripe::Scheduled::now()),
515 ..Default::default()
516 },
517 )
518 .await?;
519
520 return Ok(Json(ManageBillingSubscriptionResponse {
521 billing_portal_session_url: None,
522 }));
523 }
524
525 let subscription_item_to_update = stripe_subscription
526 .items
527 .data
528 .iter()
529 .find_map(|item| {
530 let price = item.price.as_ref()?;
531
532 if price.id == zed_free_price_id {
533 Some(item.id.clone())
534 } else {
535 None
536 }
537 })
538 .context("No subscription item to update")?;
539
540 Some(CreateBillingPortalSessionFlowData {
541 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
542 subscription_update_confirm: Some(
543 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
544 subscription: subscription.stripe_subscription_id,
545 items: vec![
546 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
547 id: subscription_item_to_update.to_string(),
548 price: Some(zed_pro_price_id.to_string()),
549 quantity: Some(1),
550 },
551 ],
552 discounts: None,
553 },
554 ),
555 ..Default::default()
556 })
557 }
558 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
559 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
560 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
561 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
562 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
563 return_url: format!(
564 "{}{path}",
565 app.config.zed_dot_dev_url(),
566 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
567 ),
568 }),
569 ..Default::default()
570 }),
571 ..Default::default()
572 }),
573 ManageSubscriptionIntent::Cancel => {
574 if subscription.kind == Some(SubscriptionKind::ZedFree) {
575 return Err(Error::http(
576 StatusCode::BAD_REQUEST,
577 "free subscription cannot be canceled".into(),
578 ));
579 }
580
581 Some(CreateBillingPortalSessionFlowData {
582 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
583 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
584 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
585 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
586 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
587 }),
588 ..Default::default()
589 }),
590 subscription_cancel: Some(
591 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
592 subscription: subscription.stripe_subscription_id,
593 retention: None,
594 },
595 ),
596 ..Default::default()
597 })
598 }
599 ManageSubscriptionIntent::StopCancellation => unreachable!(),
600 };
601
602 let mut params = CreateBillingPortalSession::new(customer_id);
603 params.flow_data = flow;
604 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
605 params.return_url = Some(&return_url);
606
607 let session = BillingPortalSession::create(&stripe_client, params).await?;
608
609 Ok(Json(ManageBillingSubscriptionResponse {
610 billing_portal_session_url: Some(session.url),
611 }))
612}
613
/// Request body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID of the user whose subscriptions should be synced.
    github_user_id: i32,
}

/// Response body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID of the synced user's billing customer.
    stripe_customer_id: String,
}
623
624async fn sync_billing_subscription(
625 Extension(app): Extension<Arc<AppState>>,
626 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
627) -> Result<Json<SyncBillingSubscriptionResponse>> {
628 let Some(stripe_client) = app.stripe_client.clone() else {
629 log::error!("failed to retrieve Stripe client");
630 Err(Error::http(
631 StatusCode::NOT_IMPLEMENTED,
632 "not supported".into(),
633 ))?
634 };
635
636 let user = app
637 .db
638 .get_user_by_github_user_id(body.github_user_id)
639 .await?
640 .context("user not found")?;
641
642 let billing_customer = app
643 .db
644 .get_billing_customer_by_user_id(user.id)
645 .await?
646 .context("billing customer not found")?;
647 let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
648
649 let subscriptions = stripe_client
650 .list_subscriptions_for_customer(&stripe_customer_id)
651 .await?;
652
653 for subscription in subscriptions {
654 let subscription_id = subscription.id.clone();
655
656 sync_subscription(&app, &stripe_client, subscription)
657 .await
658 .with_context(|| {
659 format!(
660 "failed to sync subscription {subscription_id} for user {}",
661 user.id,
662 )
663 })?;
664 }
665
666 Ok(Json(SyncBillingSubscriptionResponse {
667 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
668 }))
669}
670
/// How long we wait between successive polls of the Stripe events API.
///
/// The interval balances two concerns:
/// 1. It should be short enough that changes made in Stripe are reflected quickly.
/// 2. It should be long enough that polling doesn't eat into our rate limits.
///
/// For comparison, the Sequin folks report polling every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
///
/// We currently poll every 5 seconds.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// Page size used when listing Stripe events.
///
/// This is Stripe's maximum, chosen so each poll needs as few requests as possible:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we will see
/// during a single poll before we stop retrieving further pages.
///
/// This keeps us from over-fetching events from the Stripe events API that we
/// have already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
697
698/// Polls the Stripe events API periodically to reconcile the records in our
699/// database with the data in Stripe.
700pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
701 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
702 log::warn!("failed to retrieve Stripe client");
703 return;
704 };
705 let Some(stripe_client) = app.stripe_client.clone() else {
706 log::warn!("failed to retrieve Stripe client");
707 return;
708 };
709
710 let executor = app.executor.clone();
711 executor.spawn_detached({
712 let executor = executor.clone();
713 async move {
714 loop {
715 poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
716 .await
717 .log_err();
718
719 executor.sleep(POLL_EVENTS_INTERVAL).await;
720 }
721 }
722 });
723}
724
725async fn poll_stripe_events(
726 app: &Arc<AppState>,
727 rpc_server: &Arc<Server>,
728 stripe_client: &Arc<dyn StripeClient>,
729 real_stripe_client: &stripe::Client,
730) -> anyhow::Result<()> {
731 fn event_type_to_string(event_type: EventType) -> String {
732 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
733 // so we need to unquote it.
734 event_type.to_string().trim_matches('"').to_string()
735 }
736
737 let event_types = [
738 EventType::CustomerCreated,
739 EventType::CustomerUpdated,
740 EventType::CustomerSubscriptionCreated,
741 EventType::CustomerSubscriptionUpdated,
742 EventType::CustomerSubscriptionPaused,
743 EventType::CustomerSubscriptionResumed,
744 EventType::CustomerSubscriptionDeleted,
745 ]
746 .into_iter()
747 .map(event_type_to_string)
748 .collect::<Vec<_>>();
749
750 let mut pages_of_already_processed_events = 0;
751 let mut unprocessed_events = Vec::new();
752
753 log::info!(
754 "Stripe events: starting retrieval for {}",
755 event_types.join(", ")
756 );
757 let mut params = ListEvents::new();
758 params.types = Some(event_types.clone());
759 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
760
761 let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
762 .await?
763 .paginate(params);
764
765 loop {
766 let processed_event_ids = {
767 let event_ids = event_pages
768 .page
769 .data
770 .iter()
771 .map(|event| event.id.as_str())
772 .collect::<Vec<_>>();
773 app.db
774 .get_processed_stripe_events_by_event_ids(&event_ids)
775 .await?
776 .into_iter()
777 .map(|event| event.stripe_event_id)
778 .collect::<Vec<_>>()
779 };
780
781 let mut processed_events_in_page = 0;
782 let events_in_page = event_pages.page.data.len();
783 for event in &event_pages.page.data {
784 if processed_event_ids.contains(&event.id.to_string()) {
785 processed_events_in_page += 1;
786 log::debug!("Stripe events: already processed '{}', skipping", event.id);
787 } else {
788 unprocessed_events.push(event.clone());
789 }
790 }
791
792 if processed_events_in_page == events_in_page {
793 pages_of_already_processed_events += 1;
794 }
795
796 if event_pages.page.has_more {
797 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
798 {
799 log::info!(
800 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
801 );
802 break;
803 } else {
804 log::info!("Stripe events: retrieving next page");
805 event_pages = event_pages.next(&real_stripe_client).await?;
806 }
807 } else {
808 break;
809 }
810 }
811
812 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
813
814 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
815 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
816
817 for event in unprocessed_events {
818 let event_id = event.id.clone();
819 let processed_event_params = CreateProcessedStripeEventParams {
820 stripe_event_id: event.id.to_string(),
821 stripe_event_type: event_type_to_string(event.type_),
822 stripe_event_created_timestamp: event.created,
823 };
824
825 // If the event has happened too far in the past, we don't want to
826 // process it and risk overwriting other more-recent updates.
827 //
828 // 1 day was chosen arbitrarily. This could be made longer or shorter.
829 let one_day = Duration::from_secs(24 * 60 * 60);
830 let a_day_ago = Utc::now() - one_day;
831 if a_day_ago.timestamp() > event.created {
832 log::info!(
833 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
834 event_id
835 );
836 app.db
837 .create_processed_stripe_event(&processed_event_params)
838 .await?;
839
840 continue;
841 }
842
843 let process_result = match event.type_ {
844 EventType::CustomerCreated | EventType::CustomerUpdated => {
845 handle_customer_event(app, real_stripe_client, event).await
846 }
847 EventType::CustomerSubscriptionCreated
848 | EventType::CustomerSubscriptionUpdated
849 | EventType::CustomerSubscriptionPaused
850 | EventType::CustomerSubscriptionResumed
851 | EventType::CustomerSubscriptionDeleted => {
852 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
853 }
854 _ => Ok(()),
855 };
856
857 if let Some(()) = process_result
858 .with_context(|| format!("failed to process event {event_id} successfully"))
859 .log_err()
860 {
861 app.db
862 .create_processed_stripe_event(&processed_event_params)
863 .await?;
864 }
865 }
866
867 Ok(())
868}
869
870async fn handle_customer_event(
871 app: &Arc<AppState>,
872 _stripe_client: &stripe::Client,
873 event: stripe::Event,
874) -> anyhow::Result<()> {
875 let EventObject::Customer(customer) = event.data.object else {
876 bail!("unexpected event payload for {}", event.id);
877 };
878
879 log::info!("handling Stripe {} event: {}", event.type_, event.id);
880
881 let Some(email) = customer.email else {
882 log::info!("Stripe customer has no email: skipping");
883 return Ok(());
884 };
885
886 let Some(user) = app.db.get_user_by_email(&email).await? else {
887 log::info!("no user found for email: skipping");
888 return Ok(());
889 };
890
891 if let Some(existing_customer) = app
892 .db
893 .get_billing_customer_by_stripe_customer_id(&customer.id)
894 .await?
895 {
896 app.db
897 .update_billing_customer(
898 existing_customer.id,
899 &UpdateBillingCustomerParams {
900 // For now we just leave the information as-is, as it is not
901 // likely to change.
902 ..Default::default()
903 },
904 )
905 .await?;
906 } else {
907 app.db
908 .create_billing_customer(&CreateBillingCustomerParams {
909 user_id: user.id,
910 stripe_customer_id: customer.id.to_string(),
911 })
912 .await?;
913 }
914
915 Ok(())
916}
917
918async fn sync_subscription(
919 app: &Arc<AppState>,
920 stripe_client: &Arc<dyn StripeClient>,
921 subscription: StripeSubscription,
922) -> anyhow::Result<billing_customer::Model> {
923 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
924 stripe_billing
925 .determine_subscription_kind(&subscription)
926 .await
927 } else {
928 None
929 };
930
931 let billing_customer =
932 find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
933 .await?
934 .context("billing customer not found")?;
935
936 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
937 if subscription.status == SubscriptionStatus::Trialing {
938 let current_period_start =
939 DateTime::from_timestamp(subscription.current_period_start, 0)
940 .context("No trial subscription period start")?;
941
942 app.db
943 .update_billing_customer(
944 billing_customer.id,
945 &UpdateBillingCustomerParams {
946 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
947 ..Default::default()
948 },
949 )
950 .await?;
951 }
952 }
953
954 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
955 && subscription
956 .cancellation_details
957 .as_ref()
958 .and_then(|details| details.reason)
959 .map_or(false, |reason| {
960 reason == StripeCancellationDetailsReason::PaymentFailed
961 });
962
963 if was_canceled_due_to_payment_failure {
964 app.db
965 .update_billing_customer(
966 billing_customer.id,
967 &UpdateBillingCustomerParams {
968 has_overdue_invoices: ActiveValue::set(true),
969 ..Default::default()
970 },
971 )
972 .await?;
973 }
974
975 if let Some(existing_subscription) = app
976 .db
977 .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
978 .await?
979 {
980 app.db
981 .update_billing_subscription(
982 existing_subscription.id,
983 &UpdateBillingSubscriptionParams {
984 billing_customer_id: ActiveValue::set(billing_customer.id),
985 kind: ActiveValue::set(subscription_kind),
986 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
987 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
988 stripe_cancel_at: ActiveValue::set(
989 subscription
990 .cancel_at
991 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
992 .map(|time| time.naive_utc()),
993 ),
994 stripe_cancellation_reason: ActiveValue::set(
995 subscription
996 .cancellation_details
997 .and_then(|details| details.reason)
998 .map(|reason| reason.into()),
999 ),
1000 stripe_current_period_start: ActiveValue::set(Some(
1001 subscription.current_period_start,
1002 )),
1003 stripe_current_period_end: ActiveValue::set(Some(
1004 subscription.current_period_end,
1005 )),
1006 },
1007 )
1008 .await?;
1009 } else {
1010 if let Some(existing_subscription) = app
1011 .db
1012 .get_active_billing_subscription(billing_customer.user_id)
1013 .await?
1014 {
1015 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1016 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1017 {
1018 let stripe_subscription_id = StripeSubscriptionId(
1019 existing_subscription.stripe_subscription_id.clone().into(),
1020 );
1021
1022 stripe_client
1023 .cancel_subscription(&stripe_subscription_id)
1024 .await?;
1025 } else {
1026 // If the user already has an active billing subscription, ignore the
1027 // event and return an `Ok` to signal that it was processed
1028 // successfully.
1029 //
1030 // There is the possibility that this could cause us to not create a
1031 // subscription in the following scenario:
1032 //
1033 // 1. User has an active subscription A
1034 // 2. User cancels subscription A
1035 // 3. User creates a new subscription B
1036 // 4. We process the new subscription B before the cancellation of subscription A
1037 // 5. User ends up with no subscriptions
1038 //
1039 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1040
1041 log::info!(
1042 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1043 user_id = billing_customer.user_id,
1044 subscription_id = subscription.id
1045 );
1046 return Ok(billing_customer);
1047 }
1048 }
1049
1050 app.db
1051 .create_billing_subscription(&CreateBillingSubscriptionParams {
1052 billing_customer_id: billing_customer.id,
1053 kind: subscription_kind,
1054 stripe_subscription_id: subscription.id.to_string(),
1055 stripe_subscription_status: subscription.status.into(),
1056 stripe_cancellation_reason: subscription
1057 .cancellation_details
1058 .and_then(|details| details.reason)
1059 .map(|reason| reason.into()),
1060 stripe_current_period_start: Some(subscription.current_period_start),
1061 stripe_current_period_end: Some(subscription.current_period_end),
1062 })
1063 .await?;
1064 }
1065
1066 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1067 if subscription.status == SubscriptionStatus::Canceled
1068 || subscription.status == SubscriptionStatus::Paused
1069 {
1070 let already_has_active_billing_subscription = app
1071 .db
1072 .has_active_billing_subscription(billing_customer.user_id)
1073 .await?;
1074 if !already_has_active_billing_subscription {
1075 let stripe_customer_id =
1076 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1077
1078 stripe_billing
1079 .subscribe_to_zed_free(stripe_customer_id)
1080 .await?;
1081 }
1082 }
1083 }
1084
1085 Ok(billing_customer)
1086}
1087
1088async fn handle_customer_subscription_event(
1089 app: &Arc<AppState>,
1090 rpc_server: &Arc<Server>,
1091 stripe_client: &Arc<dyn StripeClient>,
1092 event: stripe::Event,
1093) -> anyhow::Result<()> {
1094 let EventObject::Subscription(subscription) = event.data.object else {
1095 bail!("unexpected event payload for {}", event.id);
1096 };
1097
1098 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1099
1100 let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
1101
1102 // When the user's subscription changes, push down any changes to their plan.
1103 rpc_server
1104 .update_plan_for_user(billing_customer.user_id)
1105 .await
1106 .trace_err();
1107
1108 // When the user's subscription changes, we want to refresh their LLM tokens
1109 // to either grant/revoke access.
1110 rpc_server
1111 .refresh_llm_tokens_for_user(billing_customer.user_id)
1112 .await;
1113
1114 Ok(())
1115}
1116
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID of the user whose usage is requested.
    github_user_id: i32,
}
1121
/// Consumption of one metered resource within the user's current billing period.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Units consumed so far in the period.
    pub used: i32,
    /// Allowance granted by the plan, or `None` when the plan's usage is unlimited.
    pub limit: Option<i32>,
    /// Units left before `limit` is reached, or `None` when usage is unlimited.
    pub remaining: Option<i32>,
}
1128
/// Request count for a single model in a single completion mode, taken from one of the
/// user's current subscription usage meters.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the requests were made to.
    pub model: String,
    /// Completion mode the requests were made in.
    pub mode: CompletionMode,
    /// Number of requests recorded on the usage meter.
    pub requests: i32,
}
1135
/// Usage figures reported for the user's current billing period.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Model requests counted against the plan's model-request limit.
    pub model_requests: UsageCounts,
    /// Per-model, per-mode breakdown of model requests.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit predictions counted against the plan's edit-prediction limit.
    pub edit_predictions: UsageCounts,
}
1142
/// Response body returned by `get_current_usage`.
///
/// The `Default` value (empty `plan`, no `current_usage`) is what `get_current_usage` returns
/// when the user has no active billing subscription or the subscription has no current period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// String form of the user's plan, as produced by `Plan::as_str`.
    pub plan: String,
    /// Usage for the current billing period, when one is known.
    pub current_usage: Option<CurrentUsage>,
}
1148
1149async fn get_current_usage(
1150 Extension(app): Extension<Arc<AppState>>,
1151 Query(params): Query<GetCurrentUsageParams>,
1152) -> Result<Json<GetCurrentUsageResponse>> {
1153 let user = app
1154 .db
1155 .get_user_by_github_user_id(params.github_user_id)
1156 .await?
1157 .context("user not found")?;
1158
1159 let feature_flags = app.db.get_user_flags(user.id).await?;
1160 let has_extended_trial = feature_flags
1161 .iter()
1162 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1163
1164 let Some(llm_db) = app.llm_db.clone() else {
1165 return Err(Error::http(
1166 StatusCode::NOT_IMPLEMENTED,
1167 "LLM database not available".into(),
1168 ));
1169 };
1170
1171 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1172 return Ok(Json(GetCurrentUsageResponse::default()));
1173 };
1174
1175 let subscription_period = maybe!({
1176 let period_start_at = subscription.current_period_start_at()?;
1177 let period_end_at = subscription.current_period_end_at()?;
1178
1179 Some((period_start_at, period_end_at))
1180 });
1181
1182 let Some((period_start_at, period_end_at)) = subscription_period else {
1183 return Ok(Json(GetCurrentUsageResponse::default()));
1184 };
1185
1186 let usage = llm_db
1187 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1188 .await?;
1189
1190 let plan = subscription
1191 .kind
1192 .map(Into::into)
1193 .unwrap_or(zed_llm_client::Plan::ZedFree);
1194
1195 let model_requests_limit = match plan.model_requests_limit() {
1196 zed_llm_client::UsageLimit::Limited(limit) => {
1197 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1198 1_000
1199 } else {
1200 limit
1201 };
1202
1203 Some(limit)
1204 }
1205 zed_llm_client::UsageLimit::Unlimited => None,
1206 };
1207
1208 let edit_predictions_limit = match plan.edit_predictions_limit() {
1209 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1210 zed_llm_client::UsageLimit::Unlimited => None,
1211 };
1212
1213 let Some(usage) = usage else {
1214 return Ok(Json(GetCurrentUsageResponse {
1215 plan: plan.as_str().to_string(),
1216 current_usage: Some(CurrentUsage {
1217 model_requests: UsageCounts {
1218 used: 0,
1219 limit: model_requests_limit,
1220 remaining: model_requests_limit,
1221 },
1222 model_request_usage: Vec::new(),
1223 edit_predictions: UsageCounts {
1224 used: 0,
1225 limit: edit_predictions_limit,
1226 remaining: edit_predictions_limit,
1227 },
1228 }),
1229 }));
1230 };
1231
1232 let subscription_usage_meters = llm_db
1233 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1234 .await?;
1235
1236 let model_request_usage = subscription_usage_meters
1237 .into_iter()
1238 .filter_map(|(usage_meter, _usage)| {
1239 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1240
1241 Some(ModelRequestUsage {
1242 model: model.name.clone(),
1243 mode: usage_meter.mode,
1244 requests: usage_meter.requests,
1245 })
1246 })
1247 .collect::<Vec<_>>();
1248
1249 Ok(Json(GetCurrentUsageResponse {
1250 plan: plan.as_str().to_string(),
1251 current_usage: Some(CurrentUsage {
1252 model_requests: UsageCounts {
1253 used: usage.model_requests,
1254 limit: model_requests_limit,
1255 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1256 },
1257 model_request_usage,
1258 edit_predictions: UsageCounts {
1259 used: usage.edit_predictions,
1260 limit: edit_predictions_limit,
1261 remaining: edit_predictions_limit
1262 .map(|limit| (limit - usage.edit_predictions).max(0)),
1263 },
1264 }),
1265 }))
1266}
1267
1268impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1269 fn from(value: SubscriptionStatus) -> Self {
1270 match value {
1271 SubscriptionStatus::Incomplete => Self::Incomplete,
1272 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1273 SubscriptionStatus::Trialing => Self::Trialing,
1274 SubscriptionStatus::Active => Self::Active,
1275 SubscriptionStatus::PastDue => Self::PastDue,
1276 SubscriptionStatus::Canceled => Self::Canceled,
1277 SubscriptionStatus::Unpaid => Self::Unpaid,
1278 SubscriptionStatus::Paused => Self::Paused,
1279 }
1280 }
1281}
1282
1283impl From<CancellationDetailsReason> for StripeCancellationReason {
1284 fn from(value: CancellationDetailsReason) -> Self {
1285 match value {
1286 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1287 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1288 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1289 }
1290 }
1291}
1292
1293/// Finds or creates a billing customer using the provided customer.
1294pub async fn find_or_create_billing_customer(
1295 app: &Arc<AppState>,
1296 stripe_client: &dyn StripeClient,
1297 customer_id: &StripeCustomerId,
1298) -> anyhow::Result<Option<billing_customer::Model>> {
1299 // If we already have a billing customer record associated with the Stripe customer,
1300 // there's nothing more we need to do.
1301 if let Some(billing_customer) = app
1302 .db
1303 .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
1304 .await?
1305 {
1306 return Ok(Some(billing_customer));
1307 }
1308
1309 let customer = stripe_client.get_customer(customer_id).await?;
1310
1311 let Some(email) = customer.email else {
1312 return Ok(None);
1313 };
1314
1315 let Some(user) = app.db.get_user_by_email(&email).await? else {
1316 return Ok(None);
1317 };
1318
1319 let billing_customer = app
1320 .db
1321 .create_billing_customer(&CreateBillingCustomerParams {
1322 user_id: user.id,
1323 stripe_customer_id: customer.id.to_string(),
1324 })
1325 .await?;
1326
1327 Ok(Some(billing_customer))
1328}
1329
/// How long the background sync loop sleeps between successive Stripe usage syncs.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1331
1332pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1333 let Some(stripe_billing) = app.stripe_billing.clone() else {
1334 log::warn!("failed to retrieve Stripe billing object");
1335 return;
1336 };
1337 let Some(llm_db) = app.llm_db.clone() else {
1338 log::warn!("failed to retrieve LLM database");
1339 return;
1340 };
1341
1342 let executor = app.executor.clone();
1343 executor.spawn_detached({
1344 let executor = executor.clone();
1345 async move {
1346 loop {
1347 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1348 .await
1349 .context("failed to sync LLM request usage to Stripe")
1350 .trace_err();
1351 executor
1352 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1353 .await;
1354 }
1355 }
1356 });
1357}
1358
1359async fn sync_model_request_usage_with_stripe(
1360 app: &Arc<AppState>,
1361 llm_db: &Arc<LlmDatabase>,
1362 stripe_billing: &Arc<StripeBilling>,
1363) -> anyhow::Result<()> {
1364 log::info!("Stripe usage sync: Starting");
1365 let started_at = Utc::now();
1366
1367 let staff_users = app.db.get_staff_users().await?;
1368 let staff_user_ids = staff_users
1369 .iter()
1370 .map(|user| user.id)
1371 .collect::<HashSet<UserId>>();
1372
1373 let usage_meters = llm_db
1374 .get_current_subscription_usage_meters(Utc::now())
1375 .await?;
1376 let mut usage_meters_by_user_id =
1377 HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
1378 for (usage_meter, usage) in usage_meters {
1379 let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
1380 meters.push(usage_meter);
1381 }
1382
1383 log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
1384 let get_zed_pro_subscriptions_started_at = Utc::now();
1385 let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
1386 log::info!(
1387 "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
1388 billing_subscriptions.len(),
1389 Utc::now() - get_zed_pro_subscriptions_started_at
1390 );
1391
1392 let claude_sonnet_4 = stripe_billing
1393 .find_price_by_lookup_key("claude-sonnet-4-requests")
1394 .await?;
1395 let claude_sonnet_4_max = stripe_billing
1396 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1397 .await?;
1398 let claude_opus_4 = stripe_billing
1399 .find_price_by_lookup_key("claude-opus-4-requests")
1400 .await?;
1401 let claude_opus_4_max = stripe_billing
1402 .find_price_by_lookup_key("claude-opus-4-requests-max")
1403 .await?;
1404 let claude_3_5_sonnet = stripe_billing
1405 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1406 .await?;
1407 let claude_3_7_sonnet = stripe_billing
1408 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1409 .await?;
1410 let claude_3_7_sonnet_max = stripe_billing
1411 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1412 .await?;
1413
1414 let model_mode_combinations = [
1415 ("claude-opus-4", CompletionMode::Max),
1416 ("claude-opus-4", CompletionMode::Normal),
1417 ("claude-sonnet-4", CompletionMode::Max),
1418 ("claude-sonnet-4", CompletionMode::Normal),
1419 ("claude-3-7-sonnet", CompletionMode::Max),
1420 ("claude-3-7-sonnet", CompletionMode::Normal),
1421 ("claude-3-5-sonnet", CompletionMode::Normal),
1422 ];
1423
1424 let billing_subscription_count = billing_subscriptions.len();
1425
1426 log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
1427
1428 for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
1429 maybe!(async {
1430 if staff_user_ids.contains(&user_id) {
1431 return anyhow::Ok(());
1432 }
1433
1434 let stripe_customer_id =
1435 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1436 let stripe_subscription_id =
1437 StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1438
1439 let usage_meters = usage_meters_by_user_id.get(&user_id);
1440
1441 for (model, mode) in &model_mode_combinations {
1442 let Ok(model) =
1443 llm_db.model(LanguageModelProvider::Anthropic, model)
1444 else {
1445 log::warn!("Failed to load model for user {user_id}: {model}");
1446 continue;
1447 };
1448
1449 let (price, meter_event_name) = match model.name.as_str() {
1450 "claude-opus-4" => match mode {
1451 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1452 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1453 },
1454 "claude-sonnet-4" => match mode {
1455 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1456 CompletionMode::Max => {
1457 (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
1458 }
1459 },
1460 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1461 "claude-3-7-sonnet" => match mode {
1462 CompletionMode::Normal => {
1463 (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
1464 }
1465 CompletionMode::Max => {
1466 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1467 }
1468 },
1469 model_name => {
1470 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1471 }
1472 };
1473
1474 let model_requests = usage_meters
1475 .and_then(|usage_meters| {
1476 usage_meters
1477 .iter()
1478 .find(|meter| meter.model_id == model.id && meter.mode == *mode)
1479 })
1480 .map(|usage_meter| usage_meter.requests)
1481 .unwrap_or(0);
1482
1483 if model_requests > 0 {
1484 stripe_billing
1485 .subscribe_to_price(&stripe_subscription_id, price)
1486 .await?;
1487 }
1488
1489 stripe_billing
1490 .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
1491 .await
1492 .with_context(|| {
1493 format!(
1494 "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
1495 )
1496 })?;
1497 }
1498
1499 Ok(())
1500 })
1501 .await
1502 .log_err();
1503 }
1504
1505 log::info!(
1506 "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
1507 Utc::now() - started_at
1508 );
1509
1510 Ok(())
1511}