1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
21 PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
31use crate::rpc::{ResultExt as _, Server};
32use crate::stripe_client::{
33 StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
34 StripeSubscriptionId,
35};
36use crate::{AppState, Error, Result};
37use crate::{db::UserId, llm::db::LlmDatabase};
38use crate::{
39 db::{
40 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
41 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
42 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
43 },
44 stripe_billing::StripeBilling,
45};
46
47pub fn router() -> Router {
48 Router::new()
49 .route(
50 "/billing/preferences",
51 get(get_billing_preferences).put(update_billing_preferences),
52 )
53 .route(
54 "/billing/subscriptions",
55 get(list_billing_subscriptions).post(create_billing_subscription),
56 )
57 .route(
58 "/billing/subscriptions/manage",
59 post(manage_billing_subscription),
60 )
61 .route(
62 "/billing/subscriptions/migrate",
63 post(migrate_to_new_billing),
64 )
65 .route(
66 "/billing/subscriptions/sync",
67 post(sync_billing_subscription),
68 )
69 .route("/billing/usage", get(get_current_usage))
70}
71
72#[derive(Debug, Deserialize)]
73struct GetBillingPreferencesParams {
74 github_user_id: i32,
75}
76
77#[derive(Debug, Serialize)]
78struct BillingPreferencesResponse {
79 trial_started_at: Option<String>,
80 max_monthly_llm_usage_spending_in_cents: i32,
81 model_request_overages_enabled: bool,
82 model_request_overages_spend_limit_in_cents: i32,
83}
84
85async fn get_billing_preferences(
86 Extension(app): Extension<Arc<AppState>>,
87 Query(params): Query<GetBillingPreferencesParams>,
88) -> Result<Json<BillingPreferencesResponse>> {
89 let user = app
90 .db
91 .get_user_by_github_user_id(params.github_user_id)
92 .await?
93 .context("user not found")?;
94
95 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
96 let preferences = app.db.get_billing_preferences(user.id).await?;
97
98 Ok(Json(BillingPreferencesResponse {
99 trial_started_at: billing_customer
100 .and_then(|billing_customer| billing_customer.trial_started_at)
101 .map(|trial_started_at| {
102 trial_started_at
103 .and_utc()
104 .to_rfc3339_opts(SecondsFormat::Millis, true)
105 }),
106 max_monthly_llm_usage_spending_in_cents: preferences
107 .as_ref()
108 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
109 preferences.max_monthly_llm_usage_spending_in_cents
110 }),
111 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
112 preferences.model_request_overages_enabled
113 }),
114 model_request_overages_spend_limit_in_cents: preferences
115 .as_ref()
116 .map_or(0, |preferences| {
117 preferences.model_request_overages_spend_limit_in_cents
118 }),
119 }))
120}
121
122#[derive(Debug, Deserialize)]
123struct UpdateBillingPreferencesBody {
124 github_user_id: i32,
125 #[serde(default)]
126 max_monthly_llm_usage_spending_in_cents: i32,
127 #[serde(default)]
128 model_request_overages_enabled: bool,
129 #[serde(default)]
130 model_request_overages_spend_limit_in_cents: i32,
131}
132
133async fn update_billing_preferences(
134 Extension(app): Extension<Arc<AppState>>,
135 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
136 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
137) -> Result<Json<BillingPreferencesResponse>> {
138 let user = app
139 .db
140 .get_user_by_github_user_id(body.github_user_id)
141 .await?
142 .context("user not found")?;
143
144 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
145
146 let max_monthly_llm_usage_spending_in_cents =
147 body.max_monthly_llm_usage_spending_in_cents.max(0);
148 let model_request_overages_spend_limit_in_cents =
149 body.model_request_overages_spend_limit_in_cents.max(0);
150
151 let billing_preferences =
152 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
153 app.db
154 .update_billing_preferences(
155 user.id,
156 &UpdateBillingPreferencesParams {
157 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
158 max_monthly_llm_usage_spending_in_cents,
159 ),
160 model_request_overages_enabled: ActiveValue::set(
161 body.model_request_overages_enabled,
162 ),
163 model_request_overages_spend_limit_in_cents: ActiveValue::set(
164 model_request_overages_spend_limit_in_cents,
165 ),
166 },
167 )
168 .await?
169 } else {
170 app.db
171 .create_billing_preferences(
172 user.id,
173 &crate::db::CreateBillingPreferencesParams {
174 max_monthly_llm_usage_spending_in_cents,
175 model_request_overages_enabled: body.model_request_overages_enabled,
176 model_request_overages_spend_limit_in_cents,
177 },
178 )
179 .await?
180 };
181
182 SnowflakeRow::new(
183 "Billing Preferences Updated",
184 Some(user.metrics_id),
185 user.admin,
186 None,
187 json!({
188 "user_id": user.id,
189 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
190 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
191 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
192 }),
193 )
194 .write(&app.kinesis_client, &app.config.kinesis_stream)
195 .await
196 .log_err();
197
198 rpc_server.refresh_llm_tokens_for_user(user.id).await;
199
200 Ok(Json(BillingPreferencesResponse {
201 trial_started_at: billing_customer
202 .and_then(|billing_customer| billing_customer.trial_started_at)
203 .map(|trial_started_at| {
204 trial_started_at
205 .and_utc()
206 .to_rfc3339_opts(SecondsFormat::Millis, true)
207 }),
208 max_monthly_llm_usage_spending_in_cents: billing_preferences
209 .max_monthly_llm_usage_spending_in_cents,
210 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
211 model_request_overages_spend_limit_in_cents: billing_preferences
212 .model_request_overages_spend_limit_in_cents,
213 }))
214}
215
216#[derive(Debug, Deserialize)]
217struct ListBillingSubscriptionsParams {
218 github_user_id: i32,
219}
220
221#[derive(Debug, Serialize)]
222struct BillingSubscriptionJson {
223 id: BillingSubscriptionId,
224 name: String,
225 status: StripeSubscriptionStatus,
226 trial_end_at: Option<String>,
227 cancel_at: Option<String>,
228 /// Whether this subscription can be canceled.
229 is_cancelable: bool,
230}
231
232#[derive(Debug, Serialize)]
233struct ListBillingSubscriptionsResponse {
234 subscriptions: Vec<BillingSubscriptionJson>,
235}
236
237async fn list_billing_subscriptions(
238 Extension(app): Extension<Arc<AppState>>,
239 Query(params): Query<ListBillingSubscriptionsParams>,
240) -> Result<Json<ListBillingSubscriptionsResponse>> {
241 let user = app
242 .db
243 .get_user_by_github_user_id(params.github_user_id)
244 .await?
245 .context("user not found")?;
246
247 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
248
249 Ok(Json(ListBillingSubscriptionsResponse {
250 subscriptions: subscriptions
251 .into_iter()
252 .map(|subscription| BillingSubscriptionJson {
253 id: subscription.id,
254 name: match subscription.kind {
255 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
256 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
257 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
258 None => "Zed LLM Usage".to_string(),
259 },
260 status: subscription.stripe_subscription_status,
261 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
262 maybe!({
263 let end_at = subscription.stripe_current_period_end?;
264 let end_at = DateTime::from_timestamp(end_at, 0)?;
265
266 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
267 })
268 } else {
269 None
270 },
271 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
272 cancel_at
273 .and_utc()
274 .to_rfc3339_opts(SecondsFormat::Millis, true)
275 }),
276 is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
277 && subscription.stripe_subscription_status.is_cancelable()
278 && subscription.stripe_cancel_at.is_none(),
279 })
280 .collect(),
281 }))
282}
283
284#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
285#[serde(rename_all = "snake_case")]
286enum ProductCode {
287 ZedPro,
288 ZedProTrial,
289}
290
291#[derive(Debug, Deserialize)]
292struct CreateBillingSubscriptionBody {
293 github_user_id: i32,
294 product: ProductCode,
295}
296
297#[derive(Debug, Serialize)]
298struct CreateBillingSubscriptionResponse {
299 checkout_session_url: String,
300}
301
302/// Initiates a Stripe Checkout session for creating a billing subscription.
303async fn create_billing_subscription(
304 Extension(app): Extension<Arc<AppState>>,
305 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
306) -> Result<Json<CreateBillingSubscriptionResponse>> {
307 let user = app
308 .db
309 .get_user_by_github_user_id(body.github_user_id)
310 .await?
311 .context("user not found")?;
312
313 let Some(stripe_billing) = app.stripe_billing.clone() else {
314 log::error!("failed to retrieve Stripe billing object");
315 Err(Error::http(
316 StatusCode::NOT_IMPLEMENTED,
317 "not supported".into(),
318 ))?
319 };
320
321 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
322 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
323 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
324
325 if !is_checkout_allowed {
326 return Err(Error::http(
327 StatusCode::CONFLICT,
328 "user already has an active subscription".into(),
329 ));
330 }
331 }
332
333 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
334 if let Some(existing_billing_customer) = &existing_billing_customer {
335 if existing_billing_customer.has_overdue_invoices {
336 return Err(Error::http(
337 StatusCode::PAYMENT_REQUIRED,
338 "user has overdue invoices".into(),
339 ));
340 }
341 }
342
343 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
344 StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
345 } else {
346 stripe_billing
347 .find_or_create_customer_by_email(user.email_address.as_deref())
348 .await?
349 };
350
351 let success_url = format!(
352 "{}/account?checkout_complete=1",
353 app.config.zed_dot_dev_url()
354 );
355
356 let checkout_session_url = match body.product {
357 ProductCode::ZedPro => {
358 stripe_billing
359 .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
360 .await?
361 }
362 ProductCode::ZedProTrial => {
363 if let Some(existing_billing_customer) = &existing_billing_customer {
364 if existing_billing_customer.trial_started_at.is_some() {
365 return Err(Error::http(
366 StatusCode::FORBIDDEN,
367 "user already used free trial".into(),
368 ));
369 }
370 }
371
372 let feature_flags = app.db.get_user_flags(user.id).await?;
373
374 stripe_billing
375 .checkout_with_zed_pro_trial(
376 &customer_id,
377 &user.github_login,
378 feature_flags,
379 &success_url,
380 )
381 .await?
382 }
383 };
384
385 Ok(Json(CreateBillingSubscriptionResponse {
386 checkout_session_url,
387 }))
388}
389
390#[derive(Debug, PartialEq, Deserialize)]
391#[serde(rename_all = "snake_case")]
392enum ManageSubscriptionIntent {
393 /// The user intends to manage their subscription.
394 ///
395 /// This will open the Stripe billing portal without putting the user in a specific flow.
396 ManageSubscription,
397 /// The user intends to update their payment method.
398 UpdatePaymentMethod,
399 /// The user intends to upgrade to Zed Pro.
400 UpgradeToPro,
401 /// The user intends to cancel their subscription.
402 Cancel,
403 /// The user intends to stop the cancellation of their subscription.
404 StopCancellation,
405}
406
407#[derive(Debug, Deserialize)]
408struct ManageBillingSubscriptionBody {
409 github_user_id: i32,
410 intent: ManageSubscriptionIntent,
411 /// The ID of the subscription to manage.
412 subscription_id: BillingSubscriptionId,
413 redirect_to: Option<String>,
414}
415
416#[derive(Debug, Serialize)]
417struct ManageBillingSubscriptionResponse {
418 billing_portal_session_url: Option<String>,
419}
420
421/// Initiates a Stripe customer portal session for managing a billing subscription.
422async fn manage_billing_subscription(
423 Extension(app): Extension<Arc<AppState>>,
424 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
425) -> Result<Json<ManageBillingSubscriptionResponse>> {
426 let user = app
427 .db
428 .get_user_by_github_user_id(body.github_user_id)
429 .await?
430 .context("user not found")?;
431
432 let Some(stripe_client) = app.real_stripe_client.clone() else {
433 log::error!("failed to retrieve Stripe client");
434 Err(Error::http(
435 StatusCode::NOT_IMPLEMENTED,
436 "not supported".into(),
437 ))?
438 };
439
440 let Some(stripe_billing) = app.stripe_billing.clone() else {
441 log::error!("failed to retrieve Stripe billing object");
442 Err(Error::http(
443 StatusCode::NOT_IMPLEMENTED,
444 "not supported".into(),
445 ))?
446 };
447
448 let customer = app
449 .db
450 .get_billing_customer_by_user_id(user.id)
451 .await?
452 .context("billing customer not found")?;
453 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
454 .context("failed to parse customer ID")?;
455
456 let subscription = app
457 .db
458 .get_billing_subscription_by_id(body.subscription_id)
459 .await?
460 .context("subscription not found")?;
461 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
462 .context("failed to parse subscription ID")?;
463
464 if body.intent == ManageSubscriptionIntent::StopCancellation {
465 let updated_stripe_subscription = Subscription::update(
466 &stripe_client,
467 &subscription_id,
468 stripe::UpdateSubscription {
469 cancel_at_period_end: Some(false),
470 ..Default::default()
471 },
472 )
473 .await?;
474
475 app.db
476 .update_billing_subscription(
477 subscription.id,
478 &UpdateBillingSubscriptionParams {
479 stripe_cancel_at: ActiveValue::set(
480 updated_stripe_subscription
481 .cancel_at
482 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
483 .map(|time| time.naive_utc()),
484 ),
485 ..Default::default()
486 },
487 )
488 .await?;
489
490 return Ok(Json(ManageBillingSubscriptionResponse {
491 billing_portal_session_url: None,
492 }));
493 }
494
495 let flow = match body.intent {
496 ManageSubscriptionIntent::ManageSubscription => None,
497 ManageSubscriptionIntent::UpgradeToPro => {
498 let zed_pro_price_id: stripe::PriceId =
499 stripe_billing.zed_pro_price_id().await?.try_into()?;
500 let zed_free_price_id: stripe::PriceId =
501 stripe_billing.zed_free_price_id().await?.try_into()?;
502
503 let stripe_subscription =
504 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
505
506 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
507 && stripe_subscription.items.data.iter().any(|item| {
508 item.price
509 .as_ref()
510 .map_or(false, |price| price.id == zed_pro_price_id)
511 });
512 if is_on_zed_pro_trial {
513 let payment_methods = PaymentMethod::list(
514 &stripe_client,
515 &stripe::ListPaymentMethods {
516 customer: Some(stripe_subscription.customer.id()),
517 ..Default::default()
518 },
519 )
520 .await?;
521
522 let has_payment_method = !payment_methods.data.is_empty();
523 if !has_payment_method {
524 return Err(Error::http(
525 StatusCode::BAD_REQUEST,
526 "missing payment method".into(),
527 ));
528 }
529
530 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
531 Subscription::update(
532 &stripe_client,
533 &stripe_subscription.id,
534 stripe::UpdateSubscription {
535 trial_end: Some(stripe::Scheduled::now()),
536 ..Default::default()
537 },
538 )
539 .await?;
540
541 return Ok(Json(ManageBillingSubscriptionResponse {
542 billing_portal_session_url: None,
543 }));
544 }
545
546 let subscription_item_to_update = stripe_subscription
547 .items
548 .data
549 .iter()
550 .find_map(|item| {
551 let price = item.price.as_ref()?;
552
553 if price.id == zed_free_price_id {
554 Some(item.id.clone())
555 } else {
556 None
557 }
558 })
559 .context("No subscription item to update")?;
560
561 Some(CreateBillingPortalSessionFlowData {
562 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
563 subscription_update_confirm: Some(
564 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
565 subscription: subscription.stripe_subscription_id,
566 items: vec![
567 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
568 id: subscription_item_to_update.to_string(),
569 price: Some(zed_pro_price_id.to_string()),
570 quantity: Some(1),
571 },
572 ],
573 discounts: None,
574 },
575 ),
576 ..Default::default()
577 })
578 }
579 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
580 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
581 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
582 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
583 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
584 return_url: format!(
585 "{}{path}",
586 app.config.zed_dot_dev_url(),
587 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
588 ),
589 }),
590 ..Default::default()
591 }),
592 ..Default::default()
593 }),
594 ManageSubscriptionIntent::Cancel => {
595 if subscription.kind == Some(SubscriptionKind::ZedFree) {
596 return Err(Error::http(
597 StatusCode::BAD_REQUEST,
598 "free subscription cannot be canceled".into(),
599 ));
600 }
601
602 Some(CreateBillingPortalSessionFlowData {
603 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
604 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
605 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
606 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
607 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
608 }),
609 ..Default::default()
610 }),
611 subscription_cancel: Some(
612 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
613 subscription: subscription.stripe_subscription_id,
614 retention: None,
615 },
616 ),
617 ..Default::default()
618 })
619 }
620 ManageSubscriptionIntent::StopCancellation => unreachable!(),
621 };
622
623 let mut params = CreateBillingPortalSession::new(customer_id);
624 params.flow_data = flow;
625 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
626 params.return_url = Some(&return_url);
627
628 let session = BillingPortalSession::create(&stripe_client, params).await?;
629
630 Ok(Json(ManageBillingSubscriptionResponse {
631 billing_portal_session_url: Some(session.url),
632 }))
633}
634
635#[derive(Debug, Deserialize)]
636struct MigrateToNewBillingBody {
637 github_user_id: i32,
638}
639
640#[derive(Debug, Serialize)]
641struct MigrateToNewBillingResponse {
642 /// The ID of the subscription that was canceled.
643 canceled_subscription_id: Option<String>,
644}
645
646async fn migrate_to_new_billing(
647 Extension(app): Extension<Arc<AppState>>,
648 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
649) -> Result<Json<MigrateToNewBillingResponse>> {
650 let Some(stripe_client) = app.real_stripe_client.clone() else {
651 log::error!("failed to retrieve Stripe client");
652 Err(Error::http(
653 StatusCode::NOT_IMPLEMENTED,
654 "not supported".into(),
655 ))?
656 };
657
658 let user = app
659 .db
660 .get_user_by_github_user_id(body.github_user_id)
661 .await?
662 .context("user not found")?;
663
664 let old_billing_subscriptions_by_user = app
665 .db
666 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
667 .await?;
668
669 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
670 old_billing_subscriptions_by_user.get(&user.id)
671 {
672 let stripe_subscription_id = billing_subscription
673 .stripe_subscription_id
674 .parse::<stripe::SubscriptionId>()
675 .context("failed to parse Stripe subscription ID from database")?;
676
677 Subscription::cancel(
678 &stripe_client,
679 &stripe_subscription_id,
680 stripe::CancelSubscription {
681 invoice_now: Some(true),
682 ..Default::default()
683 },
684 )
685 .await?;
686
687 Some(stripe_subscription_id)
688 } else {
689 None
690 };
691
692 let all_feature_flags = app.db.list_feature_flags().await?;
693 let user_feature_flags = app.db.get_user_flags(user.id).await?;
694
695 for feature_flag in ["new-billing", "assistant2"] {
696 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
697 if already_in_feature_flag {
698 continue;
699 }
700
701 let feature_flag = all_feature_flags
702 .iter()
703 .find(|flag| flag.flag == feature_flag)
704 .context("failed to find feature flag: {feature_flag:?}")?;
705
706 app.db.add_user_flag(user.id, feature_flag.id).await?;
707 }
708
709 Ok(Json(MigrateToNewBillingResponse {
710 canceled_subscription_id: canceled_subscription_id
711 .map(|subscription_id| subscription_id.to_string()),
712 }))
713}
714
715#[derive(Debug, Deserialize)]
716struct SyncBillingSubscriptionBody {
717 github_user_id: i32,
718}
719
720#[derive(Debug, Serialize)]
721struct SyncBillingSubscriptionResponse {
722 stripe_customer_id: String,
723}
724
725async fn sync_billing_subscription(
726 Extension(app): Extension<Arc<AppState>>,
727 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
728) -> Result<Json<SyncBillingSubscriptionResponse>> {
729 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
730 log::error!("failed to retrieve Stripe client");
731 Err(Error::http(
732 StatusCode::NOT_IMPLEMENTED,
733 "not supported".into(),
734 ))?
735 };
736 let Some(stripe_client) = app.stripe_client.clone() else {
737 log::error!("failed to retrieve Stripe client");
738 Err(Error::http(
739 StatusCode::NOT_IMPLEMENTED,
740 "not supported".into(),
741 ))?
742 };
743
744 let user = app
745 .db
746 .get_user_by_github_user_id(body.github_user_id)
747 .await?
748 .context("user not found")?;
749
750 let billing_customer = app
751 .db
752 .get_billing_customer_by_user_id(user.id)
753 .await?
754 .context("billing customer not found")?;
755 let stripe_customer_id = billing_customer
756 .stripe_customer_id
757 .parse::<stripe::CustomerId>()
758 .context("failed to parse Stripe customer ID from database")?;
759
760 let subscriptions = Subscription::list(
761 &real_stripe_client,
762 &stripe::ListSubscriptions {
763 customer: Some(stripe_customer_id),
764 // Sync all non-canceled subscriptions.
765 status: None,
766 ..Default::default()
767 },
768 )
769 .await?;
770
771 for subscription in subscriptions.data {
772 let subscription_id = subscription.id.clone();
773
774 sync_subscription(&app, &stripe_client, subscription.into())
775 .await
776 .with_context(|| {
777 format!(
778 "failed to sync subscription {subscription_id} for user {}",
779 user.id,
780 )
781 })?;
782 }
783
784 Ok(Json(SyncBillingSubscriptionResponse {
785 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
786 }))
787}
788
/// How long to wait between successive polls of the Stripe events API.
///
/// The interval balances two concerns:
/// 1. It should be short enough that changes made in Stripe show up quickly.
/// 2. It should be long enough that polling doesn't eat into our rate limits.
///
/// For comparison, the Sequin folks report polling at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// How many events to request per page.
///
/// This is 100, the maximum Stripe allows, so that we need fewer requests.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we will see
/// before we stop fetching more pages.
///
/// This keeps us from over-fetching events from the Stripe events API that we
/// have already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
815
816/// Polls the Stripe events API periodically to reconcile the records in our
817/// database with the data in Stripe.
818pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
819 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
820 log::warn!("failed to retrieve Stripe client");
821 return;
822 };
823 let Some(stripe_client) = app.stripe_client.clone() else {
824 log::warn!("failed to retrieve Stripe client");
825 return;
826 };
827
828 let executor = app.executor.clone();
829 executor.spawn_detached({
830 let executor = executor.clone();
831 async move {
832 loop {
833 poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
834 .await
835 .log_err();
836
837 executor.sleep(POLL_EVENTS_INTERVAL).await;
838 }
839 }
840 });
841}
842
843async fn poll_stripe_events(
844 app: &Arc<AppState>,
845 rpc_server: &Arc<Server>,
846 stripe_client: &Arc<dyn StripeClient>,
847 real_stripe_client: &stripe::Client,
848) -> anyhow::Result<()> {
849 fn event_type_to_string(event_type: EventType) -> String {
850 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
851 // so we need to unquote it.
852 event_type.to_string().trim_matches('"').to_string()
853 }
854
855 let event_types = [
856 EventType::CustomerCreated,
857 EventType::CustomerUpdated,
858 EventType::CustomerSubscriptionCreated,
859 EventType::CustomerSubscriptionUpdated,
860 EventType::CustomerSubscriptionPaused,
861 EventType::CustomerSubscriptionResumed,
862 EventType::CustomerSubscriptionDeleted,
863 ]
864 .into_iter()
865 .map(event_type_to_string)
866 .collect::<Vec<_>>();
867
868 let mut pages_of_already_processed_events = 0;
869 let mut unprocessed_events = Vec::new();
870
871 log::info!(
872 "Stripe events: starting retrieval for {}",
873 event_types.join(", ")
874 );
875 let mut params = ListEvents::new();
876 params.types = Some(event_types.clone());
877 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
878
879 let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
880 .await?
881 .paginate(params);
882
883 loop {
884 let processed_event_ids = {
885 let event_ids = event_pages
886 .page
887 .data
888 .iter()
889 .map(|event| event.id.as_str())
890 .collect::<Vec<_>>();
891 app.db
892 .get_processed_stripe_events_by_event_ids(&event_ids)
893 .await?
894 .into_iter()
895 .map(|event| event.stripe_event_id)
896 .collect::<Vec<_>>()
897 };
898
899 let mut processed_events_in_page = 0;
900 let events_in_page = event_pages.page.data.len();
901 for event in &event_pages.page.data {
902 if processed_event_ids.contains(&event.id.to_string()) {
903 processed_events_in_page += 1;
904 log::debug!("Stripe events: already processed '{}', skipping", event.id);
905 } else {
906 unprocessed_events.push(event.clone());
907 }
908 }
909
910 if processed_events_in_page == events_in_page {
911 pages_of_already_processed_events += 1;
912 }
913
914 if event_pages.page.has_more {
915 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
916 {
917 log::info!(
918 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
919 );
920 break;
921 } else {
922 log::info!("Stripe events: retrieving next page");
923 event_pages = event_pages.next(&real_stripe_client).await?;
924 }
925 } else {
926 break;
927 }
928 }
929
930 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
931
932 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
933 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
934
935 for event in unprocessed_events {
936 let event_id = event.id.clone();
937 let processed_event_params = CreateProcessedStripeEventParams {
938 stripe_event_id: event.id.to_string(),
939 stripe_event_type: event_type_to_string(event.type_),
940 stripe_event_created_timestamp: event.created,
941 };
942
943 // If the event has happened too far in the past, we don't want to
944 // process it and risk overwriting other more-recent updates.
945 //
946 // 1 day was chosen arbitrarily. This could be made longer or shorter.
947 let one_day = Duration::from_secs(24 * 60 * 60);
948 let a_day_ago = Utc::now() - one_day;
949 if a_day_ago.timestamp() > event.created {
950 log::info!(
951 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
952 event_id
953 );
954 app.db
955 .create_processed_stripe_event(&processed_event_params)
956 .await?;
957
958 continue;
959 }
960
961 let process_result = match event.type_ {
962 EventType::CustomerCreated | EventType::CustomerUpdated => {
963 handle_customer_event(app, real_stripe_client, event).await
964 }
965 EventType::CustomerSubscriptionCreated
966 | EventType::CustomerSubscriptionUpdated
967 | EventType::CustomerSubscriptionPaused
968 | EventType::CustomerSubscriptionResumed
969 | EventType::CustomerSubscriptionDeleted => {
970 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
971 }
972 _ => Ok(()),
973 };
974
975 if let Some(()) = process_result
976 .with_context(|| format!("failed to process event {event_id} successfully"))
977 .log_err()
978 {
979 app.db
980 .create_processed_stripe_event(&processed_event_params)
981 .await?;
982 }
983 }
984
985 Ok(())
986}
987
988async fn handle_customer_event(
989 app: &Arc<AppState>,
990 _stripe_client: &stripe::Client,
991 event: stripe::Event,
992) -> anyhow::Result<()> {
993 let EventObject::Customer(customer) = event.data.object else {
994 bail!("unexpected event payload for {}", event.id);
995 };
996
997 log::info!("handling Stripe {} event: {}", event.type_, event.id);
998
999 let Some(email) = customer.email else {
1000 log::info!("Stripe customer has no email: skipping");
1001 return Ok(());
1002 };
1003
1004 let Some(user) = app.db.get_user_by_email(&email).await? else {
1005 log::info!("no user found for email: skipping");
1006 return Ok(());
1007 };
1008
1009 if let Some(existing_customer) = app
1010 .db
1011 .get_billing_customer_by_stripe_customer_id(&customer.id)
1012 .await?
1013 {
1014 app.db
1015 .update_billing_customer(
1016 existing_customer.id,
1017 &UpdateBillingCustomerParams {
1018 // For now we just leave the information as-is, as it is not
1019 // likely to change.
1020 ..Default::default()
1021 },
1022 )
1023 .await?;
1024 } else {
1025 app.db
1026 .create_billing_customer(&CreateBillingCustomerParams {
1027 user_id: user.id,
1028 stripe_customer_id: customer.id.to_string(),
1029 })
1030 .await?;
1031 }
1032
1033 Ok(())
1034}
1035
1036async fn sync_subscription(
1037 app: &Arc<AppState>,
1038 stripe_client: &Arc<dyn StripeClient>,
1039 subscription: StripeSubscription,
1040) -> anyhow::Result<billing_customer::Model> {
1041 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1042 stripe_billing
1043 .determine_subscription_kind(&subscription)
1044 .await
1045 } else {
1046 None
1047 };
1048
1049 let billing_customer =
1050 find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
1051 .await?
1052 .context("billing customer not found")?;
1053
1054 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1055 if subscription.status == SubscriptionStatus::Trialing {
1056 let current_period_start =
1057 DateTime::from_timestamp(subscription.current_period_start, 0)
1058 .context("No trial subscription period start")?;
1059
1060 app.db
1061 .update_billing_customer(
1062 billing_customer.id,
1063 &UpdateBillingCustomerParams {
1064 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1065 ..Default::default()
1066 },
1067 )
1068 .await?;
1069 }
1070 }
1071
1072 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1073 && subscription
1074 .cancellation_details
1075 .as_ref()
1076 .and_then(|details| details.reason)
1077 .map_or(false, |reason| {
1078 reason == StripeCancellationDetailsReason::PaymentFailed
1079 });
1080
1081 if was_canceled_due_to_payment_failure {
1082 app.db
1083 .update_billing_customer(
1084 billing_customer.id,
1085 &UpdateBillingCustomerParams {
1086 has_overdue_invoices: ActiveValue::set(true),
1087 ..Default::default()
1088 },
1089 )
1090 .await?;
1091 }
1092
1093 if let Some(existing_subscription) = app
1094 .db
1095 .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
1096 .await?
1097 {
1098 app.db
1099 .update_billing_subscription(
1100 existing_subscription.id,
1101 &UpdateBillingSubscriptionParams {
1102 billing_customer_id: ActiveValue::set(billing_customer.id),
1103 kind: ActiveValue::set(subscription_kind),
1104 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1105 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1106 stripe_cancel_at: ActiveValue::set(
1107 subscription
1108 .cancel_at
1109 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1110 .map(|time| time.naive_utc()),
1111 ),
1112 stripe_cancellation_reason: ActiveValue::set(
1113 subscription
1114 .cancellation_details
1115 .and_then(|details| details.reason)
1116 .map(|reason| reason.into()),
1117 ),
1118 stripe_current_period_start: ActiveValue::set(Some(
1119 subscription.current_period_start,
1120 )),
1121 stripe_current_period_end: ActiveValue::set(Some(
1122 subscription.current_period_end,
1123 )),
1124 },
1125 )
1126 .await?;
1127 } else {
1128 if let Some(existing_subscription) = app
1129 .db
1130 .get_active_billing_subscription(billing_customer.user_id)
1131 .await?
1132 {
1133 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1134 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1135 {
1136 let stripe_subscription_id = StripeSubscriptionId(
1137 existing_subscription.stripe_subscription_id.clone().into(),
1138 );
1139
1140 stripe_client
1141 .cancel_subscription(&stripe_subscription_id)
1142 .await?;
1143 } else {
1144 // If the user already has an active billing subscription, ignore the
1145 // event and return an `Ok` to signal that it was processed
1146 // successfully.
1147 //
1148 // There is the possibility that this could cause us to not create a
1149 // subscription in the following scenario:
1150 //
1151 // 1. User has an active subscription A
1152 // 2. User cancels subscription A
1153 // 3. User creates a new subscription B
1154 // 4. We process the new subscription B before the cancellation of subscription A
1155 // 5. User ends up with no subscriptions
1156 //
1157 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1158
1159 log::info!(
1160 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1161 user_id = billing_customer.user_id,
1162 subscription_id = subscription.id
1163 );
1164 return Ok(billing_customer);
1165 }
1166 }
1167
1168 app.db
1169 .create_billing_subscription(&CreateBillingSubscriptionParams {
1170 billing_customer_id: billing_customer.id,
1171 kind: subscription_kind,
1172 stripe_subscription_id: subscription.id.to_string(),
1173 stripe_subscription_status: subscription.status.into(),
1174 stripe_cancellation_reason: subscription
1175 .cancellation_details
1176 .and_then(|details| details.reason)
1177 .map(|reason| reason.into()),
1178 stripe_current_period_start: Some(subscription.current_period_start),
1179 stripe_current_period_end: Some(subscription.current_period_end),
1180 })
1181 .await?;
1182 }
1183
1184 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1185 if subscription.status == SubscriptionStatus::Canceled
1186 || subscription.status == SubscriptionStatus::Paused
1187 {
1188 let already_has_active_billing_subscription = app
1189 .db
1190 .has_active_billing_subscription(billing_customer.user_id)
1191 .await?;
1192 if !already_has_active_billing_subscription {
1193 let stripe_customer_id =
1194 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1195
1196 stripe_billing
1197 .subscribe_to_zed_free(stripe_customer_id)
1198 .await?;
1199 }
1200 }
1201 }
1202
1203 Ok(billing_customer)
1204}
1205
1206async fn handle_customer_subscription_event(
1207 app: &Arc<AppState>,
1208 rpc_server: &Arc<Server>,
1209 stripe_client: &Arc<dyn StripeClient>,
1210 event: stripe::Event,
1211) -> anyhow::Result<()> {
1212 let EventObject::Subscription(subscription) = event.data.object else {
1213 bail!("unexpected event payload for {}", event.id);
1214 };
1215
1216 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1217
1218 let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
1219
1220 // When the user's subscription changes, push down any changes to their plan.
1221 rpc_server
1222 .update_plan_for_user(billing_customer.user_id)
1223 .await
1224 .trace_err();
1225
1226 // When the user's subscription changes, we want to refresh their LLM tokens
1227 // to either grant/revoke access.
1228 rpc_server
1229 .refresh_llm_tokens_for_user(billing_customer.user_id)
1230 .await;
1231
1232 Ok(())
1233}
1234
1235#[derive(Debug, Deserialize)]
1236struct GetCurrentUsageParams {
1237 github_user_id: i32,
1238}
1239
1240#[derive(Debug, Serialize)]
1241struct UsageCounts {
1242 pub used: i32,
1243 pub limit: Option<i32>,
1244 pub remaining: Option<i32>,
1245}
1246
1247#[derive(Debug, Serialize)]
1248struct ModelRequestUsage {
1249 pub model: String,
1250 pub mode: CompletionMode,
1251 pub requests: i32,
1252}
1253
1254#[derive(Debug, Serialize)]
1255struct CurrentUsage {
1256 pub model_requests: UsageCounts,
1257 pub model_request_usage: Vec<ModelRequestUsage>,
1258 pub edit_predictions: UsageCounts,
1259}
1260
1261#[derive(Debug, Default, Serialize)]
1262struct GetCurrentUsageResponse {
1263 pub plan: String,
1264 pub current_usage: Option<CurrentUsage>,
1265}
1266
1267async fn get_current_usage(
1268 Extension(app): Extension<Arc<AppState>>,
1269 Query(params): Query<GetCurrentUsageParams>,
1270) -> Result<Json<GetCurrentUsageResponse>> {
1271 let user = app
1272 .db
1273 .get_user_by_github_user_id(params.github_user_id)
1274 .await?
1275 .context("user not found")?;
1276
1277 let feature_flags = app.db.get_user_flags(user.id).await?;
1278 let has_extended_trial = feature_flags
1279 .iter()
1280 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1281
1282 let Some(llm_db) = app.llm_db.clone() else {
1283 return Err(Error::http(
1284 StatusCode::NOT_IMPLEMENTED,
1285 "LLM database not available".into(),
1286 ));
1287 };
1288
1289 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1290 return Ok(Json(GetCurrentUsageResponse::default()));
1291 };
1292
1293 let subscription_period = maybe!({
1294 let period_start_at = subscription.current_period_start_at()?;
1295 let period_end_at = subscription.current_period_end_at()?;
1296
1297 Some((period_start_at, period_end_at))
1298 });
1299
1300 let Some((period_start_at, period_end_at)) = subscription_period else {
1301 return Ok(Json(GetCurrentUsageResponse::default()));
1302 };
1303
1304 let usage = llm_db
1305 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1306 .await?;
1307
1308 let plan = subscription
1309 .kind
1310 .map(Into::into)
1311 .unwrap_or(zed_llm_client::Plan::ZedFree);
1312
1313 let model_requests_limit = match plan.model_requests_limit() {
1314 zed_llm_client::UsageLimit::Limited(limit) => {
1315 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1316 1_000
1317 } else {
1318 limit
1319 };
1320
1321 Some(limit)
1322 }
1323 zed_llm_client::UsageLimit::Unlimited => None,
1324 };
1325
1326 let edit_predictions_limit = match plan.edit_predictions_limit() {
1327 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1328 zed_llm_client::UsageLimit::Unlimited => None,
1329 };
1330
1331 let Some(usage) = usage else {
1332 return Ok(Json(GetCurrentUsageResponse {
1333 plan: plan.as_str().to_string(),
1334 current_usage: Some(CurrentUsage {
1335 model_requests: UsageCounts {
1336 used: 0,
1337 limit: model_requests_limit,
1338 remaining: model_requests_limit,
1339 },
1340 model_request_usage: Vec::new(),
1341 edit_predictions: UsageCounts {
1342 used: 0,
1343 limit: edit_predictions_limit,
1344 remaining: edit_predictions_limit,
1345 },
1346 }),
1347 }));
1348 };
1349
1350 let subscription_usage_meters = llm_db
1351 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1352 .await?;
1353
1354 let model_request_usage = subscription_usage_meters
1355 .into_iter()
1356 .filter_map(|(usage_meter, _usage)| {
1357 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1358
1359 Some(ModelRequestUsage {
1360 model: model.name.clone(),
1361 mode: usage_meter.mode,
1362 requests: usage_meter.requests,
1363 })
1364 })
1365 .collect::<Vec<_>>();
1366
1367 Ok(Json(GetCurrentUsageResponse {
1368 plan: plan.as_str().to_string(),
1369 current_usage: Some(CurrentUsage {
1370 model_requests: UsageCounts {
1371 used: usage.model_requests,
1372 limit: model_requests_limit,
1373 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1374 },
1375 model_request_usage,
1376 edit_predictions: UsageCounts {
1377 used: usage.edit_predictions,
1378 limit: edit_predictions_limit,
1379 remaining: edit_predictions_limit
1380 .map(|limit| (limit - usage.edit_predictions).max(0)),
1381 },
1382 }),
1383 }))
1384}
1385
1386impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1387 fn from(value: SubscriptionStatus) -> Self {
1388 match value {
1389 SubscriptionStatus::Incomplete => Self::Incomplete,
1390 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1391 SubscriptionStatus::Trialing => Self::Trialing,
1392 SubscriptionStatus::Active => Self::Active,
1393 SubscriptionStatus::PastDue => Self::PastDue,
1394 SubscriptionStatus::Canceled => Self::Canceled,
1395 SubscriptionStatus::Unpaid => Self::Unpaid,
1396 SubscriptionStatus::Paused => Self::Paused,
1397 }
1398 }
1399}
1400
1401impl From<CancellationDetailsReason> for StripeCancellationReason {
1402 fn from(value: CancellationDetailsReason) -> Self {
1403 match value {
1404 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1405 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1406 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1407 }
1408 }
1409}
1410
1411/// Finds or creates a billing customer using the provided customer.
1412pub async fn find_or_create_billing_customer(
1413 app: &Arc<AppState>,
1414 stripe_client: &dyn StripeClient,
1415 customer_id: &StripeCustomerId,
1416) -> anyhow::Result<Option<billing_customer::Model>> {
1417 // If we already have a billing customer record associated with the Stripe customer,
1418 // there's nothing more we need to do.
1419 if let Some(billing_customer) = app
1420 .db
1421 .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
1422 .await?
1423 {
1424 return Ok(Some(billing_customer));
1425 }
1426
1427 let customer = stripe_client.get_customer(customer_id).await?;
1428
1429 let Some(email) = customer.email else {
1430 return Ok(None);
1431 };
1432
1433 let Some(user) = app.db.get_user_by_email(&email).await? else {
1434 return Ok(None);
1435 };
1436
1437 let billing_customer = app
1438 .db
1439 .create_billing_customer(&CreateBillingCustomerParams {
1440 user_id: user.id,
1441 stripe_customer_id: customer.id.to_string(),
1442 })
1443 .await?;
1444
1445 Ok(Some(billing_customer))
1446}
1447
/// How long the background task waits between syncs of LLM request usage to Stripe.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_millis(60_000);
1449
1450pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1451 let Some(stripe_billing) = app.stripe_billing.clone() else {
1452 log::warn!("failed to retrieve Stripe billing object");
1453 return;
1454 };
1455 let Some(llm_db) = app.llm_db.clone() else {
1456 log::warn!("failed to retrieve LLM database");
1457 return;
1458 };
1459
1460 let executor = app.executor.clone();
1461 executor.spawn_detached({
1462 let executor = executor.clone();
1463 async move {
1464 loop {
1465 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1466 .await
1467 .context("failed to sync LLM request usage to Stripe")
1468 .trace_err();
1469 executor
1470 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1471 .await;
1472 }
1473 }
1474 });
1475}
1476
1477async fn sync_model_request_usage_with_stripe(
1478 app: &Arc<AppState>,
1479 llm_db: &Arc<LlmDatabase>,
1480 stripe_billing: &Arc<StripeBilling>,
1481) -> anyhow::Result<()> {
1482 let staff_users = app.db.get_staff_users().await?;
1483 let staff_user_ids = staff_users
1484 .iter()
1485 .map(|user| user.id)
1486 .collect::<HashSet<UserId>>();
1487
1488 let usage_meters = llm_db
1489 .get_current_subscription_usage_meters(Utc::now())
1490 .await?;
1491 let usage_meters = usage_meters
1492 .into_iter()
1493 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1494 .collect::<Vec<_>>();
1495 let user_ids = usage_meters
1496 .iter()
1497 .map(|(_, usage)| usage.user_id)
1498 .collect::<HashSet<UserId>>();
1499 let billing_subscriptions = app
1500 .db
1501 .get_active_zed_pro_billing_subscriptions(user_ids)
1502 .await?;
1503
1504 let claude_sonnet_4 = stripe_billing
1505 .find_price_by_lookup_key("claude-sonnet-4-requests")
1506 .await?;
1507 let claude_sonnet_4_max = stripe_billing
1508 .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1509 .await?;
1510 let claude_opus_4 = stripe_billing
1511 .find_price_by_lookup_key("claude-opus-4-requests")
1512 .await?;
1513 let claude_opus_4_max = stripe_billing
1514 .find_price_by_lookup_key("claude-opus-4-requests-max")
1515 .await?;
1516 let claude_3_5_sonnet = stripe_billing
1517 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1518 .await?;
1519 let claude_3_7_sonnet = stripe_billing
1520 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1521 .await?;
1522 let claude_3_7_sonnet_max = stripe_billing
1523 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1524 .await?;
1525
1526 for (usage_meter, usage) in usage_meters {
1527 maybe!(async {
1528 let Some((billing_customer, billing_subscription)) =
1529 billing_subscriptions.get(&usage.user_id)
1530 else {
1531 bail!(
1532 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1533 usage.user_id
1534 );
1535 };
1536
1537 let stripe_customer_id =
1538 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1539 let stripe_subscription_id =
1540 StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1541
1542 let model = llm_db.model_by_id(usage_meter.model_id)?;
1543
1544 let (price, meter_event_name) = match model.name.as_str() {
1545 "claude-opus-4" => match usage_meter.mode {
1546 CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1547 CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1548 },
1549 "claude-sonnet-4" => match usage_meter.mode {
1550 CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1551 CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
1552 },
1553 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1554 "claude-3-7-sonnet" => match usage_meter.mode {
1555 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1556 CompletionMode::Max => {
1557 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1558 }
1559 },
1560 model_name => {
1561 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1562 }
1563 };
1564
1565 stripe_billing
1566 .subscribe_to_price(&stripe_subscription_id, price)
1567 .await?;
1568 stripe_billing
1569 .bill_model_request_usage(
1570 &stripe_customer_id,
1571 meter_event_name,
1572 usage_meter.requests,
1573 )
1574 .await?;
1575
1576 Ok(())
1577 })
1578 .await
1579 .log_err();
1580 }
1581
1582 Ok(())
1583}