1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::{HashMap, HashSet};
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
21 PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24use zed_llm_client::LanguageModelProvider;
25
26use crate::api::events::SnowflakeRow;
27use crate::db::billing_subscription::{
28 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
29};
30use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
31use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
32use crate::rpc::{ResultExt as _, Server};
33use crate::stripe_client::{
34 StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
35 StripeSubscriptionId, UpdateCustomerParams,
36};
37use crate::{AppState, Error, Result};
38use crate::{db::UserId, llm::db::LlmDatabase};
39use crate::{
40 db::{
41 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
42 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
43 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
44 },
45 stripe_billing::StripeBilling,
46};
47
48pub fn router() -> Router {
49 Router::new()
50 .route(
51 "/billing/preferences",
52 get(get_billing_preferences).put(update_billing_preferences),
53 )
54 .route(
55 "/billing/subscriptions",
56 get(list_billing_subscriptions).post(create_billing_subscription),
57 )
58 .route(
59 "/billing/subscriptions/manage",
60 post(manage_billing_subscription),
61 )
62 .route(
63 "/billing/subscriptions/sync",
64 post(sync_billing_subscription),
65 )
66 .route("/billing/usage", get(get_current_usage))
67}
68
/// Query parameters for `GET /billing/preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID of the user whose preferences are requested.
    github_user_id: i32,
}

/// Billing preferences returned by `GET` and `PUT /billing/preferences`.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, formatted by the handlers as an RFC 3339
    /// timestamp with millisecond precision and a `Z` suffix; `None` if no trial
    /// start has been recorded.
    trial_started_at: Option<String>,
    /// Cap on monthly LLM usage spending, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled.
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
81
82async fn get_billing_preferences(
83 Extension(app): Extension<Arc<AppState>>,
84 Query(params): Query<GetBillingPreferencesParams>,
85) -> Result<Json<BillingPreferencesResponse>> {
86 let user = app
87 .db
88 .get_user_by_github_user_id(params.github_user_id)
89 .await?
90 .context("user not found")?;
91
92 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
93 let preferences = app.db.get_billing_preferences(user.id).await?;
94
95 Ok(Json(BillingPreferencesResponse {
96 trial_started_at: billing_customer
97 .and_then(|billing_customer| billing_customer.trial_started_at)
98 .map(|trial_started_at| {
99 trial_started_at
100 .and_utc()
101 .to_rfc3339_opts(SecondsFormat::Millis, true)
102 }),
103 max_monthly_llm_usage_spending_in_cents: preferences
104 .as_ref()
105 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
106 preferences.max_monthly_llm_usage_spending_in_cents
107 }),
108 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
109 preferences.model_request_overages_enabled
110 }),
111 model_request_overages_spend_limit_in_cents: preferences
112 .as_ref()
113 .map_or(0, |preferences| {
114 preferences.model_request_overages_spend_limit_in_cents
115 }),
116 }))
117}
118
/// JSON body for `PUT /billing/preferences`.
///
/// Every field except `github_user_id` is `#[serde(default)]`, so an omitted
/// field deserializes as `0` / `false`.
///
/// NOTE(review): the update handler writes all three values, so omitting a field
/// resets the stored value rather than leaving it unchanged — confirm this is
/// intended.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID of the user whose preferences are updated.
    github_user_id: i32,
    /// New cap on monthly LLM usage spending, in cents.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// New spend limit for model request overages, in cents.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
129
130async fn update_billing_preferences(
131 Extension(app): Extension<Arc<AppState>>,
132 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
133 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
134) -> Result<Json<BillingPreferencesResponse>> {
135 let user = app
136 .db
137 .get_user_by_github_user_id(body.github_user_id)
138 .await?
139 .context("user not found")?;
140
141 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
142
143 let max_monthly_llm_usage_spending_in_cents =
144 body.max_monthly_llm_usage_spending_in_cents.max(0);
145 let model_request_overages_spend_limit_in_cents =
146 body.model_request_overages_spend_limit_in_cents.max(0);
147
148 let billing_preferences =
149 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
150 app.db
151 .update_billing_preferences(
152 user.id,
153 &UpdateBillingPreferencesParams {
154 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
155 max_monthly_llm_usage_spending_in_cents,
156 ),
157 model_request_overages_enabled: ActiveValue::set(
158 body.model_request_overages_enabled,
159 ),
160 model_request_overages_spend_limit_in_cents: ActiveValue::set(
161 model_request_overages_spend_limit_in_cents,
162 ),
163 },
164 )
165 .await?
166 } else {
167 app.db
168 .create_billing_preferences(
169 user.id,
170 &crate::db::CreateBillingPreferencesParams {
171 max_monthly_llm_usage_spending_in_cents,
172 model_request_overages_enabled: body.model_request_overages_enabled,
173 model_request_overages_spend_limit_in_cents,
174 },
175 )
176 .await?
177 };
178
179 SnowflakeRow::new(
180 "Billing Preferences Updated",
181 Some(user.metrics_id),
182 user.admin,
183 None,
184 json!({
185 "user_id": user.id,
186 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
187 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
188 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
189 }),
190 )
191 .write(&app.kinesis_client, &app.config.kinesis_stream)
192 .await
193 .log_err();
194
195 rpc_server.refresh_llm_tokens_for_user(user.id).await;
196
197 Ok(Json(BillingPreferencesResponse {
198 trial_started_at: billing_customer
199 .and_then(|billing_customer| billing_customer.trial_started_at)
200 .map(|trial_started_at| {
201 trial_started_at
202 .and_utc()
203 .to_rfc3339_opts(SecondsFormat::Millis, true)
204 }),
205 max_monthly_llm_usage_spending_in_cents: billing_preferences
206 .max_monthly_llm_usage_spending_in_cents,
207 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
208 model_request_overages_spend_limit_in_cents: billing_preferences
209 .model_request_overages_spend_limit_in_cents,
210 }))
211}
212
/// Query parameters for `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID of the user whose subscriptions are listed.
    github_user_id: i32,
}

/// A single billing subscription as returned by `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Our database ID for the subscription.
    id: BillingSubscriptionId,
    /// Human-readable plan name (e.g. "Zed Pro").
    name: String,
    /// The Stripe subscription status stored for this subscription.
    status: StripeSubscriptionStatus,
    /// The current billing period, when both its start and end are known.
    period: Option<BillingSubscriptionPeriodJson>,
    /// When the trial ends (RFC 3339); only populated for Zed Pro trials.
    trial_end_at: Option<String>,
    /// When the subscription is scheduled to be canceled (RFC 3339), if ever.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}

/// Start and end of a billing period, as RFC 3339 timestamps.
#[derive(Debug, Serialize)]
struct BillingSubscriptionPeriodJson {
    start_at: String,
    end_at: String,
}

/// Response for `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    subscriptions: Vec<BillingSubscriptionJson>,
}
240
241async fn list_billing_subscriptions(
242 Extension(app): Extension<Arc<AppState>>,
243 Query(params): Query<ListBillingSubscriptionsParams>,
244) -> Result<Json<ListBillingSubscriptionsResponse>> {
245 let user = app
246 .db
247 .get_user_by_github_user_id(params.github_user_id)
248 .await?
249 .context("user not found")?;
250
251 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
252
253 Ok(Json(ListBillingSubscriptionsResponse {
254 subscriptions: subscriptions
255 .into_iter()
256 .map(|subscription| BillingSubscriptionJson {
257 id: subscription.id,
258 name: match subscription.kind {
259 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
260 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
261 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
262 None => "Zed LLM Usage".to_string(),
263 },
264 status: subscription.stripe_subscription_status,
265 period: maybe!({
266 let start_at = subscription.current_period_start_at()?;
267 let end_at = subscription.current_period_end_at()?;
268
269 Some(BillingSubscriptionPeriodJson {
270 start_at: start_at.to_rfc3339_opts(SecondsFormat::Millis, true),
271 end_at: end_at.to_rfc3339_opts(SecondsFormat::Millis, true),
272 })
273 }),
274 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
275 maybe!({
276 let end_at = subscription.stripe_current_period_end?;
277 let end_at = DateTime::from_timestamp(end_at, 0)?;
278
279 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
280 })
281 } else {
282 None
283 },
284 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
285 cancel_at
286 .and_utc()
287 .to_rfc3339_opts(SecondsFormat::Millis, true)
288 }),
289 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
290 && subscription.stripe_subscription_status.is_cancelable()
291 && subscription.stripe_cancel_at.is_none(),
292 })
293 .collect(),
294 }))
295}
296
/// The product a Stripe Checkout session is created for.
///
/// Deserialized from snake_case strings: `"zed_pro"` or `"zed_pro_trial"`.
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    ZedPro,
    ZedProTrial,
}

/// JSON body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID of the user starting checkout.
    github_user_id: i32,
    /// The product to check out.
    product: ProductCode,
}

/// Response for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session to send the user to.
    checkout_session_url: String,
}
314
315/// Initiates a Stripe Checkout session for creating a billing subscription.
316async fn create_billing_subscription(
317 Extension(app): Extension<Arc<AppState>>,
318 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
319) -> Result<Json<CreateBillingSubscriptionResponse>> {
320 let user = app
321 .db
322 .get_user_by_github_user_id(body.github_user_id)
323 .await?
324 .context("user not found")?;
325
326 let Some(stripe_billing) = app.stripe_billing.clone() else {
327 log::error!("failed to retrieve Stripe billing object");
328 Err(Error::http(
329 StatusCode::NOT_IMPLEMENTED,
330 "not supported".into(),
331 ))?
332 };
333
334 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
335 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
336 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
337
338 if !is_checkout_allowed {
339 return Err(Error::http(
340 StatusCode::CONFLICT,
341 "user already has an active subscription".into(),
342 ));
343 }
344 }
345
346 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
347 if let Some(existing_billing_customer) = &existing_billing_customer {
348 if existing_billing_customer.has_overdue_invoices {
349 return Err(Error::http(
350 StatusCode::PAYMENT_REQUIRED,
351 "user has overdue invoices".into(),
352 ));
353 }
354 }
355
356 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
357 let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
358 if let Some(email) = user.email_address.as_deref() {
359 stripe_billing
360 .client()
361 .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
362 .await
363 // Update of email address is best-effort - continue checkout even if it fails
364 .context("error updating stripe customer email address")
365 .log_err();
366 }
367 customer_id
368 } else {
369 stripe_billing
370 .find_or_create_customer_by_email(user.email_address.as_deref())
371 .await?
372 };
373
374 let success_url = format!(
375 "{}/account?checkout_complete=1",
376 app.config.zed_dot_dev_url()
377 );
378
379 let checkout_session_url = match body.product {
380 ProductCode::ZedPro => {
381 stripe_billing
382 .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
383 .await?
384 }
385 ProductCode::ZedProTrial => {
386 if let Some(existing_billing_customer) = &existing_billing_customer {
387 if existing_billing_customer.trial_started_at.is_some() {
388 return Err(Error::http(
389 StatusCode::FORBIDDEN,
390 "user already used free trial".into(),
391 ));
392 }
393 }
394
395 let feature_flags = app.db.get_user_flags(user.id).await?;
396
397 stripe_billing
398 .checkout_with_zed_pro_trial(
399 &customer_id,
400 &user.github_login,
401 feature_flags,
402 &success_url,
403 )
404 .await?
405 }
406 };
407
408 Ok(Json(CreateBillingSubscriptionResponse {
409 checkout_session_url,
410 }))
411}
412
/// What the user wants to do via `POST /billing/subscriptions/manage`.
///
/// Deserialized from snake_case strings (e.g. `"update_payment_method"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}

/// JSON body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID of the user whose subscription is being managed.
    github_user_id: i32,
    /// The action to take.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Path appended to the zed.dev URL to build the return URL after a payment
    /// method update; the handler falls back to `/account` when it is absent.
    redirect_to: Option<String>,
}

/// Response for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session to send the user to, or `None`
    /// when the request was fulfilled directly without a portal session (for
    /// example, stopping a cancellation or ending a trial early).
    billing_portal_session_url: Option<String>,
}
443
444/// Initiates a Stripe customer portal session for managing a billing subscription.
445async fn manage_billing_subscription(
446 Extension(app): Extension<Arc<AppState>>,
447 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
448) -> Result<Json<ManageBillingSubscriptionResponse>> {
449 let user = app
450 .db
451 .get_user_by_github_user_id(body.github_user_id)
452 .await?
453 .context("user not found")?;
454
455 let Some(stripe_client) = app.real_stripe_client.clone() else {
456 log::error!("failed to retrieve Stripe client");
457 Err(Error::http(
458 StatusCode::NOT_IMPLEMENTED,
459 "not supported".into(),
460 ))?
461 };
462
463 let Some(stripe_billing) = app.stripe_billing.clone() else {
464 log::error!("failed to retrieve Stripe billing object");
465 Err(Error::http(
466 StatusCode::NOT_IMPLEMENTED,
467 "not supported".into(),
468 ))?
469 };
470
471 let customer = app
472 .db
473 .get_billing_customer_by_user_id(user.id)
474 .await?
475 .context("billing customer not found")?;
476 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
477 .context("failed to parse customer ID")?;
478
479 let subscription = app
480 .db
481 .get_billing_subscription_by_id(body.subscription_id)
482 .await?
483 .context("subscription not found")?;
484 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
485 .context("failed to parse subscription ID")?;
486
487 if body.intent == ManageSubscriptionIntent::StopCancellation {
488 let updated_stripe_subscription = Subscription::update(
489 &stripe_client,
490 &subscription_id,
491 stripe::UpdateSubscription {
492 cancel_at_period_end: Some(false),
493 ..Default::default()
494 },
495 )
496 .await?;
497
498 app.db
499 .update_billing_subscription(
500 subscription.id,
501 &UpdateBillingSubscriptionParams {
502 stripe_cancel_at: ActiveValue::set(
503 updated_stripe_subscription
504 .cancel_at
505 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
506 .map(|time| time.naive_utc()),
507 ),
508 ..Default::default()
509 },
510 )
511 .await?;
512
513 return Ok(Json(ManageBillingSubscriptionResponse {
514 billing_portal_session_url: None,
515 }));
516 }
517
518 let flow = match body.intent {
519 ManageSubscriptionIntent::ManageSubscription => None,
520 ManageSubscriptionIntent::UpgradeToPro => {
521 let zed_pro_price_id: stripe::PriceId =
522 stripe_billing.zed_pro_price_id().await?.try_into()?;
523 let zed_free_price_id: stripe::PriceId =
524 stripe_billing.zed_free_price_id().await?.try_into()?;
525
526 let stripe_subscription =
527 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
528
529 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
530 && stripe_subscription.items.data.iter().any(|item| {
531 item.price
532 .as_ref()
533 .map_or(false, |price| price.id == zed_pro_price_id)
534 });
535 if is_on_zed_pro_trial {
536 let payment_methods = PaymentMethod::list(
537 &stripe_client,
538 &stripe::ListPaymentMethods {
539 customer: Some(stripe_subscription.customer.id()),
540 ..Default::default()
541 },
542 )
543 .await?;
544
545 let has_payment_method = !payment_methods.data.is_empty();
546 if !has_payment_method {
547 return Err(Error::http(
548 StatusCode::BAD_REQUEST,
549 "missing payment method".into(),
550 ));
551 }
552
553 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
554 Subscription::update(
555 &stripe_client,
556 &stripe_subscription.id,
557 stripe::UpdateSubscription {
558 trial_end: Some(stripe::Scheduled::now()),
559 ..Default::default()
560 },
561 )
562 .await?;
563
564 return Ok(Json(ManageBillingSubscriptionResponse {
565 billing_portal_session_url: None,
566 }));
567 }
568
569 let subscription_item_to_update = stripe_subscription
570 .items
571 .data
572 .iter()
573 .find_map(|item| {
574 let price = item.price.as_ref()?;
575
576 if price.id == zed_free_price_id {
577 Some(item.id.clone())
578 } else {
579 None
580 }
581 })
582 .context("No subscription item to update")?;
583
584 Some(CreateBillingPortalSessionFlowData {
585 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
586 subscription_update_confirm: Some(
587 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
588 subscription: subscription.stripe_subscription_id,
589 items: vec![
590 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
591 id: subscription_item_to_update.to_string(),
592 price: Some(zed_pro_price_id.to_string()),
593 quantity: Some(1),
594 },
595 ],
596 discounts: None,
597 },
598 ),
599 ..Default::default()
600 })
601 }
602 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
603 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
604 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
605 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
606 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
607 return_url: format!(
608 "{}{path}",
609 app.config.zed_dot_dev_url(),
610 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
611 ),
612 }),
613 ..Default::default()
614 }),
615 ..Default::default()
616 }),
617 ManageSubscriptionIntent::Cancel => {
618 if subscription.kind == Some(SubscriptionKind::ZedFree) {
619 return Err(Error::http(
620 StatusCode::BAD_REQUEST,
621 "free subscription cannot be canceled".into(),
622 ));
623 }
624
625 Some(CreateBillingPortalSessionFlowData {
626 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
627 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
628 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
629 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
630 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
631 }),
632 ..Default::default()
633 }),
634 subscription_cancel: Some(
635 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
636 subscription: subscription.stripe_subscription_id,
637 retention: None,
638 },
639 ),
640 ..Default::default()
641 })
642 }
643 ManageSubscriptionIntent::StopCancellation => unreachable!(),
644 };
645
646 let mut params = CreateBillingPortalSession::new(customer_id);
647 params.flow_data = flow;
648 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
649 params.return_url = Some(&return_url);
650
651 let session = BillingPortalSession::create(&stripe_client, params).await?;
652
653 Ok(Json(ManageBillingSubscriptionResponse {
654 billing_portal_session_url: Some(session.url),
655 }))
656}
657
/// JSON body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID of the user whose subscriptions are synced.
    github_user_id: i32,
}

/// Response for `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID of the user's billing customer.
    stripe_customer_id: String,
}
667
668async fn sync_billing_subscription(
669 Extension(app): Extension<Arc<AppState>>,
670 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
671) -> Result<Json<SyncBillingSubscriptionResponse>> {
672 let Some(stripe_client) = app.stripe_client.clone() else {
673 log::error!("failed to retrieve Stripe client");
674 Err(Error::http(
675 StatusCode::NOT_IMPLEMENTED,
676 "not supported".into(),
677 ))?
678 };
679
680 let user = app
681 .db
682 .get_user_by_github_user_id(body.github_user_id)
683 .await?
684 .context("user not found")?;
685
686 let billing_customer = app
687 .db
688 .get_billing_customer_by_user_id(user.id)
689 .await?
690 .context("billing customer not found")?;
691 let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
692
693 let subscriptions = stripe_client
694 .list_subscriptions_for_customer(&stripe_customer_id)
695 .await?;
696
697 for subscription in subscriptions {
698 let subscription_id = subscription.id.clone();
699
700 sync_subscription(&app, &stripe_client, subscription)
701 .await
702 .with_context(|| {
703 format!(
704 "failed to sync subscription {subscription_id} for user {}",
705 user.id,
706 )
707 })?;
708 }
709
710 Ok(Json(SyncBillingSubscriptionResponse {
711 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
712 }))
713}
714
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes.
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — <https://blog.sequinstream.com/events-not-webhooks/>
///
/// We currently poll every 5 seconds.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);

/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// The number of pages consisting entirely of already-processed events after
/// which we stop retrieving further pages.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed. The count is cumulative within a single poll, so
/// the already-processed pages do not need to be consecutive.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
741
742/// Polls the Stripe events API periodically to reconcile the records in our
743/// database with the data in Stripe.
744pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
745 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
746 log::warn!("failed to retrieve Stripe client");
747 return;
748 };
749 let Some(stripe_client) = app.stripe_client.clone() else {
750 log::warn!("failed to retrieve Stripe client");
751 return;
752 };
753
754 let executor = app.executor.clone();
755 executor.spawn_detached({
756 let executor = executor.clone();
757 async move {
758 loop {
759 poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
760 .await
761 .log_err();
762
763 executor.sleep(POLL_EVENTS_INTERVAL).await;
764 }
765 }
766 });
767}
768
769async fn poll_stripe_events(
770 app: &Arc<AppState>,
771 rpc_server: &Arc<Server>,
772 stripe_client: &Arc<dyn StripeClient>,
773 real_stripe_client: &stripe::Client,
774) -> anyhow::Result<()> {
775 fn event_type_to_string(event_type: EventType) -> String {
776 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
777 // so we need to unquote it.
778 event_type.to_string().trim_matches('"').to_string()
779 }
780
781 let event_types = [
782 EventType::CustomerCreated,
783 EventType::CustomerUpdated,
784 EventType::CustomerSubscriptionCreated,
785 EventType::CustomerSubscriptionUpdated,
786 EventType::CustomerSubscriptionPaused,
787 EventType::CustomerSubscriptionResumed,
788 EventType::CustomerSubscriptionDeleted,
789 ]
790 .into_iter()
791 .map(event_type_to_string)
792 .collect::<Vec<_>>();
793
794 let mut pages_of_already_processed_events = 0;
795 let mut unprocessed_events = Vec::new();
796
797 log::info!(
798 "Stripe events: starting retrieval for {}",
799 event_types.join(", ")
800 );
801 let mut params = ListEvents::new();
802 params.types = Some(event_types.clone());
803 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
804
805 let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
806 .await?
807 .paginate(params);
808
809 loop {
810 let processed_event_ids = {
811 let event_ids = event_pages
812 .page
813 .data
814 .iter()
815 .map(|event| event.id.as_str())
816 .collect::<Vec<_>>();
817 app.db
818 .get_processed_stripe_events_by_event_ids(&event_ids)
819 .await?
820 .into_iter()
821 .map(|event| event.stripe_event_id)
822 .collect::<Vec<_>>()
823 };
824
825 let mut processed_events_in_page = 0;
826 let events_in_page = event_pages.page.data.len();
827 for event in &event_pages.page.data {
828 if processed_event_ids.contains(&event.id.to_string()) {
829 processed_events_in_page += 1;
830 log::debug!("Stripe events: already processed '{}', skipping", event.id);
831 } else {
832 unprocessed_events.push(event.clone());
833 }
834 }
835
836 if processed_events_in_page == events_in_page {
837 pages_of_already_processed_events += 1;
838 }
839
840 if event_pages.page.has_more {
841 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
842 {
843 log::info!(
844 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
845 );
846 break;
847 } else {
848 log::info!("Stripe events: retrieving next page");
849 event_pages = event_pages.next(&real_stripe_client).await?;
850 }
851 } else {
852 break;
853 }
854 }
855
856 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
857
858 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
859 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
860
861 for event in unprocessed_events {
862 let event_id = event.id.clone();
863 let processed_event_params = CreateProcessedStripeEventParams {
864 stripe_event_id: event.id.to_string(),
865 stripe_event_type: event_type_to_string(event.type_),
866 stripe_event_created_timestamp: event.created,
867 };
868
869 // If the event has happened too far in the past, we don't want to
870 // process it and risk overwriting other more-recent updates.
871 //
872 // 1 day was chosen arbitrarily. This could be made longer or shorter.
873 let one_day = Duration::from_secs(24 * 60 * 60);
874 let a_day_ago = Utc::now() - one_day;
875 if a_day_ago.timestamp() > event.created {
876 log::info!(
877 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
878 event_id
879 );
880 app.db
881 .create_processed_stripe_event(&processed_event_params)
882 .await?;
883
884 continue;
885 }
886
887 let process_result = match event.type_ {
888 EventType::CustomerCreated | EventType::CustomerUpdated => {
889 handle_customer_event(app, real_stripe_client, event).await
890 }
891 EventType::CustomerSubscriptionCreated
892 | EventType::CustomerSubscriptionUpdated
893 | EventType::CustomerSubscriptionPaused
894 | EventType::CustomerSubscriptionResumed
895 | EventType::CustomerSubscriptionDeleted => {
896 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
897 }
898 _ => Ok(()),
899 };
900
901 if let Some(()) = process_result
902 .with_context(|| format!("failed to process event {event_id} successfully"))
903 .log_err()
904 {
905 app.db
906 .create_processed_stripe_event(&processed_event_params)
907 .await?;
908 }
909 }
910
911 Ok(())
912}
913
914async fn handle_customer_event(
915 app: &Arc<AppState>,
916 _stripe_client: &stripe::Client,
917 event: stripe::Event,
918) -> anyhow::Result<()> {
919 let EventObject::Customer(customer) = event.data.object else {
920 bail!("unexpected event payload for {}", event.id);
921 };
922
923 log::info!("handling Stripe {} event: {}", event.type_, event.id);
924
925 let Some(email) = customer.email else {
926 log::info!("Stripe customer has no email: skipping");
927 return Ok(());
928 };
929
930 let Some(user) = app.db.get_user_by_email(&email).await? else {
931 log::info!("no user found for email: skipping");
932 return Ok(());
933 };
934
935 if let Some(existing_customer) = app
936 .db
937 .get_billing_customer_by_stripe_customer_id(&customer.id)
938 .await?
939 {
940 app.db
941 .update_billing_customer(
942 existing_customer.id,
943 &UpdateBillingCustomerParams {
944 // For now we just leave the information as-is, as it is not
945 // likely to change.
946 ..Default::default()
947 },
948 )
949 .await?;
950 } else {
951 app.db
952 .create_billing_customer(&CreateBillingCustomerParams {
953 user_id: user.id,
954 stripe_customer_id: customer.id.to_string(),
955 })
956 .await?;
957 }
958
959 Ok(())
960}
961
962async fn sync_subscription(
963 app: &Arc<AppState>,
964 stripe_client: &Arc<dyn StripeClient>,
965 subscription: StripeSubscription,
966) -> anyhow::Result<billing_customer::Model> {
967 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
968 stripe_billing
969 .determine_subscription_kind(&subscription)
970 .await
971 } else {
972 None
973 };
974
975 let billing_customer =
976 find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
977 .await?
978 .context("billing customer not found")?;
979
980 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
981 if subscription.status == SubscriptionStatus::Trialing {
982 let current_period_start =
983 DateTime::from_timestamp(subscription.current_period_start, 0)
984 .context("No trial subscription period start")?;
985
986 app.db
987 .update_billing_customer(
988 billing_customer.id,
989 &UpdateBillingCustomerParams {
990 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
991 ..Default::default()
992 },
993 )
994 .await?;
995 }
996 }
997
998 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
999 && subscription
1000 .cancellation_details
1001 .as_ref()
1002 .and_then(|details| details.reason)
1003 .map_or(false, |reason| {
1004 reason == StripeCancellationDetailsReason::PaymentFailed
1005 });
1006
1007 if was_canceled_due_to_payment_failure {
1008 app.db
1009 .update_billing_customer(
1010 billing_customer.id,
1011 &UpdateBillingCustomerParams {
1012 has_overdue_invoices: ActiveValue::set(true),
1013 ..Default::default()
1014 },
1015 )
1016 .await?;
1017 }
1018
1019 if let Some(existing_subscription) = app
1020 .db
1021 .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
1022 .await?
1023 {
1024 app.db
1025 .update_billing_subscription(
1026 existing_subscription.id,
1027 &UpdateBillingSubscriptionParams {
1028 billing_customer_id: ActiveValue::set(billing_customer.id),
1029 kind: ActiveValue::set(subscription_kind),
1030 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1031 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1032 stripe_cancel_at: ActiveValue::set(
1033 subscription
1034 .cancel_at
1035 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1036 .map(|time| time.naive_utc()),
1037 ),
1038 stripe_cancellation_reason: ActiveValue::set(
1039 subscription
1040 .cancellation_details
1041 .and_then(|details| details.reason)
1042 .map(|reason| reason.into()),
1043 ),
1044 stripe_current_period_start: ActiveValue::set(Some(
1045 subscription.current_period_start,
1046 )),
1047 stripe_current_period_end: ActiveValue::set(Some(
1048 subscription.current_period_end,
1049 )),
1050 },
1051 )
1052 .await?;
1053 } else {
1054 if let Some(existing_subscription) = app
1055 .db
1056 .get_active_billing_subscription(billing_customer.user_id)
1057 .await?
1058 {
1059 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1060 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1061 {
1062 let stripe_subscription_id = StripeSubscriptionId(
1063 existing_subscription.stripe_subscription_id.clone().into(),
1064 );
1065
1066 stripe_client
1067 .cancel_subscription(&stripe_subscription_id)
1068 .await?;
1069 } else {
1070 // If the user already has an active billing subscription, ignore the
1071 // event and return an `Ok` to signal that it was processed
1072 // successfully.
1073 //
1074 // There is the possibility that this could cause us to not create a
1075 // subscription in the following scenario:
1076 //
1077 // 1. User has an active subscription A
1078 // 2. User cancels subscription A
1079 // 3. User creates a new subscription B
1080 // 4. We process the new subscription B before the cancellation of subscription A
1081 // 5. User ends up with no subscriptions
1082 //
1083 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1084
1085 log::info!(
1086 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1087 user_id = billing_customer.user_id,
1088 subscription_id = subscription.id
1089 );
1090 return Ok(billing_customer);
1091 }
1092 }
1093
1094 app.db
1095 .create_billing_subscription(&CreateBillingSubscriptionParams {
1096 billing_customer_id: billing_customer.id,
1097 kind: subscription_kind,
1098 stripe_subscription_id: subscription.id.to_string(),
1099 stripe_subscription_status: subscription.status.into(),
1100 stripe_cancellation_reason: subscription
1101 .cancellation_details
1102 .and_then(|details| details.reason)
1103 .map(|reason| reason.into()),
1104 stripe_current_period_start: Some(subscription.current_period_start),
1105 stripe_current_period_end: Some(subscription.current_period_end),
1106 })
1107 .await?;
1108 }
1109
1110 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1111 if subscription.status == SubscriptionStatus::Canceled
1112 || subscription.status == SubscriptionStatus::Paused
1113 {
1114 let already_has_active_billing_subscription = app
1115 .db
1116 .has_active_billing_subscription(billing_customer.user_id)
1117 .await?;
1118 if !already_has_active_billing_subscription {
1119 let stripe_customer_id =
1120 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1121
1122 stripe_billing
1123 .subscribe_to_zed_free(stripe_customer_id)
1124 .await?;
1125 }
1126 }
1127 }
1128
1129 Ok(billing_customer)
1130}
1131
1132async fn handle_customer_subscription_event(
1133 app: &Arc<AppState>,
1134 rpc_server: &Arc<Server>,
1135 stripe_client: &Arc<dyn StripeClient>,
1136 event: stripe::Event,
1137) -> anyhow::Result<()> {
1138 let EventObject::Subscription(subscription) = event.data.object else {
1139 bail!("unexpected event payload for {}", event.id);
1140 };
1141
1142 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1143
1144 let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
1145
1146 // When the user's subscription changes, push down any changes to their plan.
1147 rpc_server
1148 .update_plan_for_user(billing_customer.user_id)
1149 .await
1150 .trace_err();
1151
1152 // When the user's subscription changes, we want to refresh their LLM tokens
1153 // to either grant/revoke access.
1154 rpc_server
1155 .refresh_llm_tokens_for_user(billing_customer.user_id)
1156 .await;
1157
1158 Ok(())
1159}
1160
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID of the user whose plan and usage are requested.
    github_user_id: i32,
}
1165
/// Usage of a single metered resource within the current billing period.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Number of units used so far in the period.
    pub used: i32,
    /// The plan's limit for this resource, or `None` when it is unlimited.
    pub limit: Option<i32>,
    /// Units left before reaching `limit`, or `None` when unlimited.
    pub remaining: Option<i32>,
}
1172
/// Number of requests made to one model in one completion mode.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the language model.
    pub model: String,
    /// Completion mode the requests were made in.
    pub mode: CompletionMode,
    /// Number of requests recorded by the corresponding usage meter.
    pub requests: i32,
}
1179
/// A user's usage within their current billing period.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Total model requests, measured against the plan's request limit.
    pub model_requests: UsageCounts,
    /// Per-model, per-mode breakdown of model requests.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit predictions, measured against the plan's edit prediction limit.
    pub edit_predictions: UsageCounts,
}
1186
/// Response body of `get_current_usage`.
///
/// The `Default` value (empty `plan`, no `current_usage`) is returned when the
/// user has no active billing subscription, or when that subscription has no
/// current billing period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// Plan identifier, as produced by the plan's `as_str`.
    pub plan: String,
    /// Usage for the current billing period, if any.
    pub current_usage: Option<CurrentUsage>,
}
1192
1193async fn get_current_usage(
1194 Extension(app): Extension<Arc<AppState>>,
1195 Query(params): Query<GetCurrentUsageParams>,
1196) -> Result<Json<GetCurrentUsageResponse>> {
1197 let user = app
1198 .db
1199 .get_user_by_github_user_id(params.github_user_id)
1200 .await?
1201 .context("user not found")?;
1202
1203 let feature_flags = app.db.get_user_flags(user.id).await?;
1204 let has_extended_trial = feature_flags
1205 .iter()
1206 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1207
1208 let Some(llm_db) = app.llm_db.clone() else {
1209 return Err(Error::http(
1210 StatusCode::NOT_IMPLEMENTED,
1211 "LLM database not available".into(),
1212 ));
1213 };
1214
1215 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1216 return Ok(Json(GetCurrentUsageResponse::default()));
1217 };
1218
1219 let subscription_period = maybe!({
1220 let period_start_at = subscription.current_period_start_at()?;
1221 let period_end_at = subscription.current_period_end_at()?;
1222
1223 Some((period_start_at, period_end_at))
1224 });
1225
1226 let Some((period_start_at, period_end_at)) = subscription_period else {
1227 return Ok(Json(GetCurrentUsageResponse::default()));
1228 };
1229
1230 let usage = llm_db
1231 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1232 .await?;
1233
1234 let plan = subscription
1235 .kind
1236 .map(Into::into)
1237 .unwrap_or(zed_llm_client::Plan::ZedFree);
1238
1239 let model_requests_limit = match plan.model_requests_limit() {
1240 zed_llm_client::UsageLimit::Limited(limit) => {
1241 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1242 1_000
1243 } else {
1244 limit
1245 };
1246
1247 Some(limit)
1248 }
1249 zed_llm_client::UsageLimit::Unlimited => None,
1250 };
1251
1252 let edit_predictions_limit = match plan.edit_predictions_limit() {
1253 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1254 zed_llm_client::UsageLimit::Unlimited => None,
1255 };
1256
1257 let Some(usage) = usage else {
1258 return Ok(Json(GetCurrentUsageResponse {
1259 plan: plan.as_str().to_string(),
1260 current_usage: Some(CurrentUsage {
1261 model_requests: UsageCounts {
1262 used: 0,
1263 limit: model_requests_limit,
1264 remaining: model_requests_limit,
1265 },
1266 model_request_usage: Vec::new(),
1267 edit_predictions: UsageCounts {
1268 used: 0,
1269 limit: edit_predictions_limit,
1270 remaining: edit_predictions_limit,
1271 },
1272 }),
1273 }));
1274 };
1275
1276 let subscription_usage_meters = llm_db
1277 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1278 .await?;
1279
1280 let model_request_usage = subscription_usage_meters
1281 .into_iter()
1282 .filter_map(|(usage_meter, _usage)| {
1283 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1284
1285 Some(ModelRequestUsage {
1286 model: model.name.clone(),
1287 mode: usage_meter.mode,
1288 requests: usage_meter.requests,
1289 })
1290 })
1291 .collect::<Vec<_>>();
1292
1293 Ok(Json(GetCurrentUsageResponse {
1294 plan: plan.as_str().to_string(),
1295 current_usage: Some(CurrentUsage {
1296 model_requests: UsageCounts {
1297 used: usage.model_requests,
1298 limit: model_requests_limit,
1299 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1300 },
1301 model_request_usage,
1302 edit_predictions: UsageCounts {
1303 used: usage.edit_predictions,
1304 limit: edit_predictions_limit,
1305 remaining: edit_predictions_limit
1306 .map(|limit| (limit - usage.edit_predictions).max(0)),
1307 },
1308 }),
1309 }))
1310}
1311
1312impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1313 fn from(value: SubscriptionStatus) -> Self {
1314 match value {
1315 SubscriptionStatus::Incomplete => Self::Incomplete,
1316 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1317 SubscriptionStatus::Trialing => Self::Trialing,
1318 SubscriptionStatus::Active => Self::Active,
1319 SubscriptionStatus::PastDue => Self::PastDue,
1320 SubscriptionStatus::Canceled => Self::Canceled,
1321 SubscriptionStatus::Unpaid => Self::Unpaid,
1322 SubscriptionStatus::Paused => Self::Paused,
1323 }
1324 }
1325}
1326
1327impl From<CancellationDetailsReason> for StripeCancellationReason {
1328 fn from(value: CancellationDetailsReason) -> Self {
1329 match value {
1330 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1331 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1332 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1333 }
1334 }
1335}
1336
1337/// Finds or creates a billing customer using the provided customer.
1338pub async fn find_or_create_billing_customer(
1339 app: &Arc<AppState>,
1340 stripe_client: &dyn StripeClient,
1341 customer_id: &StripeCustomerId,
1342) -> anyhow::Result<Option<billing_customer::Model>> {
1343 // If we already have a billing customer record associated with the Stripe customer,
1344 // there's nothing more we need to do.
1345 if let Some(billing_customer) = app
1346 .db
1347 .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
1348 .await?
1349 {
1350 return Ok(Some(billing_customer));
1351 }
1352
1353 let customer = stripe_client.get_customer(customer_id).await?;
1354
1355 let Some(email) = customer.email else {
1356 return Ok(None);
1357 };
1358
1359 let Some(user) = app.db.get_user_by_email(&email).await? else {
1360 return Ok(None);
1361 };
1362
1363 let billing_customer = app
1364 .db
1365 .create_billing_customer(&CreateBillingCustomerParams {
1366 user_id: user.id,
1367 stripe_customer_id: customer.id.to_string(),
1368 })
1369 .await?;
1370
1371 Ok(Some(billing_customer))
1372}
1373
1374const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1375
1376pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1377 let Some(stripe_billing) = app.stripe_billing.clone() else {
1378 log::warn!("failed to retrieve Stripe billing object");
1379 return;
1380 };
1381 let Some(llm_db) = app.llm_db.clone() else {
1382 log::warn!("failed to retrieve LLM database");
1383 return;
1384 };
1385
1386 let executor = app.executor.clone();
1387 executor.spawn_detached({
1388 let executor = executor.clone();
1389 async move {
1390 loop {
1391 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1392 .await
1393 .context("failed to sync LLM request usage to Stripe")
1394 .trace_err();
1395 executor
1396 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1397 .await;
1398 }
1399 }
1400 });
1401}
1402
1403async fn sync_model_request_usage_with_stripe(
1404 app: &Arc<AppState>,
1405 llm_db: &Arc<LlmDatabase>,
1406 stripe_billing: &Arc<StripeBilling>,
1407) -> anyhow::Result<()> {
1408 log::info!("Stripe usage sync: Starting");
1409 let started_at = Utc::now();
1410
1411 let staff_users = app.db.get_staff_users().await?;
1412 let staff_user_ids = staff_users
1413 .iter()
1414 .map(|user| user.id)
1415 .collect::<HashSet<UserId>>();
1416
1417 let usage_meters = llm_db
1418 .get_current_subscription_usage_meters(Utc::now())
1419 .await?;
1420 let mut usage_meters_by_user_id =
1421 HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
1422 for (usage_meter, usage) in usage_meters {
1423 let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
1424 meters.push(usage_meter);
1425 }
1426
1427 log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
1428 let get_zed_pro_subscriptions_started_at = Utc::now();
1429 let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
1430 log::info!(
1431 "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
1432 billing_subscriptions.len(),
1433 Utc::now() - get_zed_pro_subscriptions_started_at
1434 );
1435
1436 let claude_sonnet_4 = stripe_billing
1437 .find_price_by_lookup_key("claude-sonnet-4-requests")
1438 .await?;
1439 let claude_sonnet_4_max = stripe_billing
1440 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1441 .await?;
1442 let claude_opus_4 = stripe_billing
1443 .find_price_by_lookup_key("claude-opus-4-requests")
1444 .await?;
1445 let claude_opus_4_max = stripe_billing
1446 .find_price_by_lookup_key("claude-opus-4-requests-max")
1447 .await?;
1448 let claude_3_5_sonnet = stripe_billing
1449 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1450 .await?;
1451 let claude_3_7_sonnet = stripe_billing
1452 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1453 .await?;
1454 let claude_3_7_sonnet_max = stripe_billing
1455 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1456 .await?;
1457
1458 let model_mode_combinations = [
1459 ("claude-opus-4", CompletionMode::Max),
1460 ("claude-opus-4", CompletionMode::Normal),
1461 ("claude-sonnet-4", CompletionMode::Max),
1462 ("claude-sonnet-4", CompletionMode::Normal),
1463 ("claude-3-7-sonnet", CompletionMode::Max),
1464 ("claude-3-7-sonnet", CompletionMode::Normal),
1465 ("claude-3-5-sonnet", CompletionMode::Normal),
1466 ];
1467
1468 let billing_subscription_count = billing_subscriptions.len();
1469
1470 log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
1471
1472 for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
1473 maybe!(async {
1474 if staff_user_ids.contains(&user_id) {
1475 return anyhow::Ok(());
1476 }
1477
1478 let stripe_customer_id =
1479 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1480 let stripe_subscription_id =
1481 StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1482
1483 let usage_meters = usage_meters_by_user_id.get(&user_id);
1484
1485 for (model, mode) in &model_mode_combinations {
1486 let Ok(model) =
1487 llm_db.model(LanguageModelProvider::Anthropic, model)
1488 else {
1489 log::warn!("Failed to load model for user {user_id}: {model}");
1490 continue;
1491 };
1492
1493 let (price, meter_event_name) = match model.name.as_str() {
1494 "claude-opus-4" => match mode {
1495 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1496 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1497 },
1498 "claude-sonnet-4" => match mode {
1499 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1500 CompletionMode::Max => {
1501 (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
1502 }
1503 },
1504 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1505 "claude-3-7-sonnet" => match mode {
1506 CompletionMode::Normal => {
1507 (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
1508 }
1509 CompletionMode::Max => {
1510 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1511 }
1512 },
1513 model_name => {
1514 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1515 }
1516 };
1517
1518 let model_requests = usage_meters
1519 .and_then(|usage_meters| {
1520 usage_meters
1521 .iter()
1522 .find(|meter| meter.model_id == model.id && meter.mode == *mode)
1523 })
1524 .map(|usage_meter| usage_meter.requests)
1525 .unwrap_or(0);
1526
1527 if model_requests > 0 {
1528 stripe_billing
1529 .subscribe_to_price(&stripe_subscription_id, price)
1530 .await?;
1531 }
1532
1533 stripe_billing
1534 .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
1535 .await
1536 .with_context(|| {
1537 format!(
1538 "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
1539 )
1540 })?;
1541 }
1542
1543 Ok(())
1544 })
1545 .await
1546 .log_err();
1547 }
1548
1549 log::info!(
1550 "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
1551 Utc::now() - started_at
1552 );
1553
1554 Ok(())
1555}