1use anyhow::{Context as _, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
21 Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
31use crate::rpc::{ResultExt as _, Server};
32use crate::{AppState, Error, Result};
33use crate::{db::UserId, llm::db::LlmDatabase};
34use crate::{
35 db::{
36 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
37 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
38 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
39 },
40 stripe_billing::StripeBilling,
41};
42
43pub fn router() -> Router {
44 Router::new()
45 .route(
46 "/billing/preferences",
47 get(get_billing_preferences).put(update_billing_preferences),
48 )
49 .route(
50 "/billing/subscriptions",
51 get(list_billing_subscriptions).post(create_billing_subscription),
52 )
53 .route(
54 "/billing/subscriptions/manage",
55 post(manage_billing_subscription),
56 )
57 .route(
58 "/billing/subscriptions/migrate",
59 post(migrate_to_new_billing),
60 )
61 .route(
62 "/billing/subscriptions/sync",
63 post(sync_billing_subscription),
64 )
65 .route("/billing/usage", get(get_current_usage))
66}
67
68#[derive(Debug, Deserialize)]
69struct GetBillingPreferencesParams {
70 github_user_id: i32,
71}
72
73#[derive(Debug, Serialize)]
74struct BillingPreferencesResponse {
75 trial_started_at: Option<String>,
76 max_monthly_llm_usage_spending_in_cents: i32,
77 model_request_overages_enabled: bool,
78 model_request_overages_spend_limit_in_cents: i32,
79}
80
81async fn get_billing_preferences(
82 Extension(app): Extension<Arc<AppState>>,
83 Query(params): Query<GetBillingPreferencesParams>,
84) -> Result<Json<BillingPreferencesResponse>> {
85 let user = app
86 .db
87 .get_user_by_github_user_id(params.github_user_id)
88 .await?
89 .context("user not found")?;
90
91 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
92 let preferences = app.db.get_billing_preferences(user.id).await?;
93
94 Ok(Json(BillingPreferencesResponse {
95 trial_started_at: billing_customer
96 .and_then(|billing_customer| billing_customer.trial_started_at)
97 .map(|trial_started_at| {
98 trial_started_at
99 .and_utc()
100 .to_rfc3339_opts(SecondsFormat::Millis, true)
101 }),
102 max_monthly_llm_usage_spending_in_cents: preferences
103 .as_ref()
104 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
105 preferences.max_monthly_llm_usage_spending_in_cents
106 }),
107 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
108 preferences.model_request_overages_enabled
109 }),
110 model_request_overages_spend_limit_in_cents: preferences
111 .as_ref()
112 .map_or(0, |preferences| {
113 preferences.model_request_overages_spend_limit_in_cents
114 }),
115 }))
116}
117
118#[derive(Debug, Deserialize)]
119struct UpdateBillingPreferencesBody {
120 github_user_id: i32,
121 #[serde(default)]
122 max_monthly_llm_usage_spending_in_cents: i32,
123 #[serde(default)]
124 model_request_overages_enabled: bool,
125 #[serde(default)]
126 model_request_overages_spend_limit_in_cents: i32,
127}
128
129async fn update_billing_preferences(
130 Extension(app): Extension<Arc<AppState>>,
131 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
132 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
133) -> Result<Json<BillingPreferencesResponse>> {
134 let user = app
135 .db
136 .get_user_by_github_user_id(body.github_user_id)
137 .await?
138 .context("user not found")?;
139
140 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
141
142 let max_monthly_llm_usage_spending_in_cents =
143 body.max_monthly_llm_usage_spending_in_cents.max(0);
144 let model_request_overages_spend_limit_in_cents =
145 body.model_request_overages_spend_limit_in_cents.max(0);
146
147 let billing_preferences =
148 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
149 app.db
150 .update_billing_preferences(
151 user.id,
152 &UpdateBillingPreferencesParams {
153 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
154 max_monthly_llm_usage_spending_in_cents,
155 ),
156 model_request_overages_enabled: ActiveValue::set(
157 body.model_request_overages_enabled,
158 ),
159 model_request_overages_spend_limit_in_cents: ActiveValue::set(
160 model_request_overages_spend_limit_in_cents,
161 ),
162 },
163 )
164 .await?
165 } else {
166 app.db
167 .create_billing_preferences(
168 user.id,
169 &crate::db::CreateBillingPreferencesParams {
170 max_monthly_llm_usage_spending_in_cents,
171 model_request_overages_enabled: body.model_request_overages_enabled,
172 model_request_overages_spend_limit_in_cents,
173 },
174 )
175 .await?
176 };
177
178 SnowflakeRow::new(
179 "Billing Preferences Updated",
180 Some(user.metrics_id),
181 user.admin,
182 None,
183 json!({
184 "user_id": user.id,
185 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
186 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
187 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
188 }),
189 )
190 .write(&app.kinesis_client, &app.config.kinesis_stream)
191 .await
192 .log_err();
193
194 rpc_server.refresh_llm_tokens_for_user(user.id).await;
195
196 Ok(Json(BillingPreferencesResponse {
197 trial_started_at: billing_customer
198 .and_then(|billing_customer| billing_customer.trial_started_at)
199 .map(|trial_started_at| {
200 trial_started_at
201 .and_utc()
202 .to_rfc3339_opts(SecondsFormat::Millis, true)
203 }),
204 max_monthly_llm_usage_spending_in_cents: billing_preferences
205 .max_monthly_llm_usage_spending_in_cents,
206 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
207 model_request_overages_spend_limit_in_cents: billing_preferences
208 .model_request_overages_spend_limit_in_cents,
209 }))
210}
211
212#[derive(Debug, Deserialize)]
213struct ListBillingSubscriptionsParams {
214 github_user_id: i32,
215}
216
217#[derive(Debug, Serialize)]
218struct BillingSubscriptionJson {
219 id: BillingSubscriptionId,
220 name: String,
221 status: StripeSubscriptionStatus,
222 trial_end_at: Option<String>,
223 cancel_at: Option<String>,
224 /// Whether this subscription can be canceled.
225 is_cancelable: bool,
226}
227
228#[derive(Debug, Serialize)]
229struct ListBillingSubscriptionsResponse {
230 subscriptions: Vec<BillingSubscriptionJson>,
231}
232
233async fn list_billing_subscriptions(
234 Extension(app): Extension<Arc<AppState>>,
235 Query(params): Query<ListBillingSubscriptionsParams>,
236) -> Result<Json<ListBillingSubscriptionsResponse>> {
237 let user = app
238 .db
239 .get_user_by_github_user_id(params.github_user_id)
240 .await?
241 .context("user not found")?;
242
243 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
244
245 Ok(Json(ListBillingSubscriptionsResponse {
246 subscriptions: subscriptions
247 .into_iter()
248 .map(|subscription| BillingSubscriptionJson {
249 id: subscription.id,
250 name: match subscription.kind {
251 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
252 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
253 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
254 None => "Zed LLM Usage".to_string(),
255 },
256 status: subscription.stripe_subscription_status,
257 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
258 maybe!({
259 let end_at = subscription.stripe_current_period_end?;
260 let end_at = DateTime::from_timestamp(end_at, 0)?;
261
262 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
263 })
264 } else {
265 None
266 },
267 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
268 cancel_at
269 .and_utc()
270 .to_rfc3339_opts(SecondsFormat::Millis, true)
271 }),
272 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
273 && subscription.stripe_cancel_at.is_none(),
274 })
275 .collect(),
276 }))
277}
278
279#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
280#[serde(rename_all = "snake_case")]
281enum ProductCode {
282 ZedPro,
283 ZedProTrial,
284 ZedFree,
285}
286
287#[derive(Debug, Deserialize)]
288struct CreateBillingSubscriptionBody {
289 github_user_id: i32,
290 product: ProductCode,
291}
292
293#[derive(Debug, Serialize)]
294struct CreateBillingSubscriptionResponse {
295 checkout_session_url: String,
296}
297
298/// Initiates a Stripe Checkout session for creating a billing subscription.
299async fn create_billing_subscription(
300 Extension(app): Extension<Arc<AppState>>,
301 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
302) -> Result<Json<CreateBillingSubscriptionResponse>> {
303 let user = app
304 .db
305 .get_user_by_github_user_id(body.github_user_id)
306 .await?
307 .context("user not found")?;
308
309 let Some(stripe_billing) = app.stripe_billing.clone() else {
310 log::error!("failed to retrieve Stripe billing object");
311 Err(Error::http(
312 StatusCode::NOT_IMPLEMENTED,
313 "not supported".into(),
314 ))?
315 };
316
317 if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
318 let is_checkout_allowed = body.product == ProductCode::ZedProTrial
319 && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
320
321 if !is_checkout_allowed {
322 return Err(Error::http(
323 StatusCode::CONFLICT,
324 "user already has an active subscription".into(),
325 ));
326 }
327 }
328
329 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
330 if let Some(existing_billing_customer) = &existing_billing_customer {
331 if existing_billing_customer.has_overdue_invoices {
332 return Err(Error::http(
333 StatusCode::PAYMENT_REQUIRED,
334 "user has overdue invoices".into(),
335 ));
336 }
337 }
338
339 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
340 CustomerId::from_str(&existing_customer.stripe_customer_id)
341 .context("failed to parse customer ID")?
342 } else {
343 stripe_billing
344 .find_or_create_customer_by_email(user.email_address.as_deref())
345 .await?
346 };
347
348 let success_url = format!(
349 "{}/account?checkout_complete=1",
350 app.config.zed_dot_dev_url()
351 );
352
353 let checkout_session_url = match body.product {
354 ProductCode::ZedPro => {
355 stripe_billing
356 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
357 .await?
358 }
359 ProductCode::ZedProTrial => {
360 if let Some(existing_billing_customer) = &existing_billing_customer {
361 if existing_billing_customer.trial_started_at.is_some() {
362 return Err(Error::http(
363 StatusCode::FORBIDDEN,
364 "user already used free trial".into(),
365 ));
366 }
367 }
368
369 let feature_flags = app.db.get_user_flags(user.id).await?;
370
371 stripe_billing
372 .checkout_with_zed_pro_trial(
373 customer_id,
374 &user.github_login,
375 feature_flags,
376 &success_url,
377 )
378 .await?
379 }
380 ProductCode::ZedFree => {
381 stripe_billing
382 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
383 .await?
384 }
385 };
386
387 Ok(Json(CreateBillingSubscriptionResponse {
388 checkout_session_url,
389 }))
390}
391
392#[derive(Debug, PartialEq, Deserialize)]
393#[serde(rename_all = "snake_case")]
394enum ManageSubscriptionIntent {
395 /// The user intends to manage their subscription.
396 ///
397 /// This will open the Stripe billing portal without putting the user in a specific flow.
398 ManageSubscription,
399 /// The user intends to update their payment method.
400 UpdatePaymentMethod,
401 /// The user intends to upgrade to Zed Pro.
402 UpgradeToPro,
403 /// The user intends to cancel their subscription.
404 Cancel,
405 /// The user intends to stop the cancellation of their subscription.
406 StopCancellation,
407}
408
409#[derive(Debug, Deserialize)]
410struct ManageBillingSubscriptionBody {
411 github_user_id: i32,
412 intent: ManageSubscriptionIntent,
413 /// The ID of the subscription to manage.
414 subscription_id: BillingSubscriptionId,
415 redirect_to: Option<String>,
416}
417
418#[derive(Debug, Serialize)]
419struct ManageBillingSubscriptionResponse {
420 billing_portal_session_url: Option<String>,
421}
422
423/// Initiates a Stripe customer portal session for managing a billing subscription.
424async fn manage_billing_subscription(
425 Extension(app): Extension<Arc<AppState>>,
426 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
427) -> Result<Json<ManageBillingSubscriptionResponse>> {
428 let user = app
429 .db
430 .get_user_by_github_user_id(body.github_user_id)
431 .await?
432 .context("user not found")?;
433
434 let Some(stripe_client) = app.stripe_client.clone() else {
435 log::error!("failed to retrieve Stripe client");
436 Err(Error::http(
437 StatusCode::NOT_IMPLEMENTED,
438 "not supported".into(),
439 ))?
440 };
441
442 let Some(stripe_billing) = app.stripe_billing.clone() else {
443 log::error!("failed to retrieve Stripe billing object");
444 Err(Error::http(
445 StatusCode::NOT_IMPLEMENTED,
446 "not supported".into(),
447 ))?
448 };
449
450 let customer = app
451 .db
452 .get_billing_customer_by_user_id(user.id)
453 .await?
454 .context("billing customer not found")?;
455 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
456 .context("failed to parse customer ID")?;
457
458 let subscription = app
459 .db
460 .get_billing_subscription_by_id(body.subscription_id)
461 .await?
462 .context("subscription not found")?;
463 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
464 .context("failed to parse subscription ID")?;
465
466 if body.intent == ManageSubscriptionIntent::StopCancellation {
467 let updated_stripe_subscription = Subscription::update(
468 &stripe_client,
469 &subscription_id,
470 stripe::UpdateSubscription {
471 cancel_at_period_end: Some(false),
472 ..Default::default()
473 },
474 )
475 .await?;
476
477 app.db
478 .update_billing_subscription(
479 subscription.id,
480 &UpdateBillingSubscriptionParams {
481 stripe_cancel_at: ActiveValue::set(
482 updated_stripe_subscription
483 .cancel_at
484 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
485 .map(|time| time.naive_utc()),
486 ),
487 ..Default::default()
488 },
489 )
490 .await?;
491
492 return Ok(Json(ManageBillingSubscriptionResponse {
493 billing_portal_session_url: None,
494 }));
495 }
496
497 let flow = match body.intent {
498 ManageSubscriptionIntent::ManageSubscription => None,
499 ManageSubscriptionIntent::UpgradeToPro => {
500 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
501 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
502
503 let stripe_subscription =
504 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
505
506 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
507 && stripe_subscription.items.data.iter().any(|item| {
508 item.price
509 .as_ref()
510 .map_or(false, |price| price.id == zed_pro_price_id)
511 });
512 if is_on_zed_pro_trial {
513 let payment_methods = PaymentMethod::list(
514 &stripe_client,
515 &stripe::ListPaymentMethods {
516 customer: Some(stripe_subscription.customer.id()),
517 ..Default::default()
518 },
519 )
520 .await?;
521
522 let has_payment_method = !payment_methods.data.is_empty();
523 if !has_payment_method {
524 return Err(Error::http(
525 StatusCode::BAD_REQUEST,
526 "missing payment method".into(),
527 ));
528 }
529
530 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
531 Subscription::update(
532 &stripe_client,
533 &stripe_subscription.id,
534 stripe::UpdateSubscription {
535 trial_end: Some(stripe::Scheduled::now()),
536 ..Default::default()
537 },
538 )
539 .await?;
540
541 return Ok(Json(ManageBillingSubscriptionResponse {
542 billing_portal_session_url: None,
543 }));
544 }
545
546 let subscription_item_to_update = stripe_subscription
547 .items
548 .data
549 .iter()
550 .find_map(|item| {
551 let price = item.price.as_ref()?;
552
553 if price.id == zed_free_price_id {
554 Some(item.id.clone())
555 } else {
556 None
557 }
558 })
559 .context("No subscription item to update")?;
560
561 Some(CreateBillingPortalSessionFlowData {
562 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
563 subscription_update_confirm: Some(
564 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
565 subscription: subscription.stripe_subscription_id,
566 items: vec![
567 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
568 id: subscription_item_to_update.to_string(),
569 price: Some(zed_pro_price_id.to_string()),
570 quantity: Some(1),
571 },
572 ],
573 discounts: None,
574 },
575 ),
576 ..Default::default()
577 })
578 }
579 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
580 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
581 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
582 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
583 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
584 return_url: format!(
585 "{}{path}",
586 app.config.zed_dot_dev_url(),
587 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
588 ),
589 }),
590 ..Default::default()
591 }),
592 ..Default::default()
593 }),
594 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
595 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
596 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
597 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
598 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
599 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
600 }),
601 ..Default::default()
602 }),
603 subscription_cancel: Some(
604 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
605 subscription: subscription.stripe_subscription_id,
606 retention: None,
607 },
608 ),
609 ..Default::default()
610 }),
611 ManageSubscriptionIntent::StopCancellation => unreachable!(),
612 };
613
614 let mut params = CreateBillingPortalSession::new(customer_id);
615 params.flow_data = flow;
616 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
617 params.return_url = Some(&return_url);
618
619 let session = BillingPortalSession::create(&stripe_client, params).await?;
620
621 Ok(Json(ManageBillingSubscriptionResponse {
622 billing_portal_session_url: Some(session.url),
623 }))
624}
625
626#[derive(Debug, Deserialize)]
627struct MigrateToNewBillingBody {
628 github_user_id: i32,
629}
630
631#[derive(Debug, Serialize)]
632struct MigrateToNewBillingResponse {
633 /// The ID of the subscription that was canceled.
634 canceled_subscription_id: Option<String>,
635}
636
637async fn migrate_to_new_billing(
638 Extension(app): Extension<Arc<AppState>>,
639 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
640) -> Result<Json<MigrateToNewBillingResponse>> {
641 let Some(stripe_client) = app.stripe_client.clone() else {
642 log::error!("failed to retrieve Stripe client");
643 Err(Error::http(
644 StatusCode::NOT_IMPLEMENTED,
645 "not supported".into(),
646 ))?
647 };
648
649 let user = app
650 .db
651 .get_user_by_github_user_id(body.github_user_id)
652 .await?
653 .context("user not found")?;
654
655 let old_billing_subscriptions_by_user = app
656 .db
657 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
658 .await?;
659
660 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
661 old_billing_subscriptions_by_user.get(&user.id)
662 {
663 let stripe_subscription_id = billing_subscription
664 .stripe_subscription_id
665 .parse::<stripe::SubscriptionId>()
666 .context("failed to parse Stripe subscription ID from database")?;
667
668 Subscription::cancel(
669 &stripe_client,
670 &stripe_subscription_id,
671 stripe::CancelSubscription {
672 invoice_now: Some(true),
673 ..Default::default()
674 },
675 )
676 .await?;
677
678 Some(stripe_subscription_id)
679 } else {
680 None
681 };
682
683 let all_feature_flags = app.db.list_feature_flags().await?;
684 let user_feature_flags = app.db.get_user_flags(user.id).await?;
685
686 for feature_flag in ["new-billing", "assistant2"] {
687 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
688 if already_in_feature_flag {
689 continue;
690 }
691
692 let feature_flag = all_feature_flags
693 .iter()
694 .find(|flag| flag.flag == feature_flag)
695 .context("failed to find feature flag: {feature_flag:?}")?;
696
697 app.db.add_user_flag(user.id, feature_flag.id).await?;
698 }
699
700 Ok(Json(MigrateToNewBillingResponse {
701 canceled_subscription_id: canceled_subscription_id
702 .map(|subscription_id| subscription_id.to_string()),
703 }))
704}
705
706#[derive(Debug, Deserialize)]
707struct SyncBillingSubscriptionBody {
708 github_user_id: i32,
709}
710
711#[derive(Debug, Serialize)]
712struct SyncBillingSubscriptionResponse {
713 stripe_customer_id: String,
714}
715
716async fn sync_billing_subscription(
717 Extension(app): Extension<Arc<AppState>>,
718 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
719) -> Result<Json<SyncBillingSubscriptionResponse>> {
720 let Some(stripe_client) = app.stripe_client.clone() else {
721 log::error!("failed to retrieve Stripe client");
722 Err(Error::http(
723 StatusCode::NOT_IMPLEMENTED,
724 "not supported".into(),
725 ))?
726 };
727
728 let user = app
729 .db
730 .get_user_by_github_user_id(body.github_user_id)
731 .await?
732 .context("user not found")?;
733
734 let billing_customer = app
735 .db
736 .get_billing_customer_by_user_id(user.id)
737 .await?
738 .context("billing customer not found")?;
739 let stripe_customer_id = billing_customer
740 .stripe_customer_id
741 .parse::<stripe::CustomerId>()
742 .context("failed to parse Stripe customer ID from database")?;
743
744 let subscriptions = Subscription::list(
745 &stripe_client,
746 &stripe::ListSubscriptions {
747 customer: Some(stripe_customer_id),
748 // Sync all non-canceled subscriptions.
749 status: None,
750 ..Default::default()
751 },
752 )
753 .await?;
754
755 for subscription in subscriptions.data {
756 let subscription_id = subscription.id.clone();
757
758 sync_subscription(&app, &stripe_client, subscription)
759 .await
760 .with_context(|| {
761 format!(
762 "failed to sync subscription {subscription_id} for user {}",
763 user.id,
764 )
765 })?;
766 }
767
768 Ok(Json(SyncBillingSubscriptionResponse {
769 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
770 }))
771}
772
/// How long we wait between successive polls of the Stripe events API.
///
/// The interval trades off two concerns:
/// 1. It should be short, so that changes made in Stripe show up here quickly.
/// 2. It should be long enough that polling does not eat into our rate limits.
///
/// For comparison, the Sequin folks report polling every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// Page size requested when listing Stripe events.
///
/// This is Stripe's maximum, which minimizes the number of requests we make:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we tolerate before we stop
/// retrieving further pages.
///
/// This keeps us from over-fetching events we have already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
799
800/// Polls the Stripe events API periodically to reconcile the records in our
801/// database with the data in Stripe.
802pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
803 let Some(stripe_client) = app.stripe_client.clone() else {
804 log::warn!("failed to retrieve Stripe client");
805 return;
806 };
807
808 let executor = app.executor.clone();
809 executor.spawn_detached({
810 let executor = executor.clone();
811 async move {
812 loop {
813 poll_stripe_events(&app, &rpc_server, &stripe_client)
814 .await
815 .log_err();
816
817 executor.sleep(POLL_EVENTS_INTERVAL).await;
818 }
819 }
820 });
821}
822
823async fn poll_stripe_events(
824 app: &Arc<AppState>,
825 rpc_server: &Arc<Server>,
826 stripe_client: &stripe::Client,
827) -> anyhow::Result<()> {
828 fn event_type_to_string(event_type: EventType) -> String {
829 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
830 // so we need to unquote it.
831 event_type.to_string().trim_matches('"').to_string()
832 }
833
834 let event_types = [
835 EventType::CustomerCreated,
836 EventType::CustomerUpdated,
837 EventType::CustomerSubscriptionCreated,
838 EventType::CustomerSubscriptionUpdated,
839 EventType::CustomerSubscriptionPaused,
840 EventType::CustomerSubscriptionResumed,
841 EventType::CustomerSubscriptionDeleted,
842 ]
843 .into_iter()
844 .map(event_type_to_string)
845 .collect::<Vec<_>>();
846
847 let mut pages_of_already_processed_events = 0;
848 let mut unprocessed_events = Vec::new();
849
850 log::info!(
851 "Stripe events: starting retrieval for {}",
852 event_types.join(", ")
853 );
854 let mut params = ListEvents::new();
855 params.types = Some(event_types.clone());
856 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
857
858 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
859 .await?
860 .paginate(params);
861
862 loop {
863 let processed_event_ids = {
864 let event_ids = event_pages
865 .page
866 .data
867 .iter()
868 .map(|event| event.id.as_str())
869 .collect::<Vec<_>>();
870 app.db
871 .get_processed_stripe_events_by_event_ids(&event_ids)
872 .await?
873 .into_iter()
874 .map(|event| event.stripe_event_id)
875 .collect::<Vec<_>>()
876 };
877
878 let mut processed_events_in_page = 0;
879 let events_in_page = event_pages.page.data.len();
880 for event in &event_pages.page.data {
881 if processed_event_ids.contains(&event.id.to_string()) {
882 processed_events_in_page += 1;
883 log::debug!("Stripe events: already processed '{}', skipping", event.id);
884 } else {
885 unprocessed_events.push(event.clone());
886 }
887 }
888
889 if processed_events_in_page == events_in_page {
890 pages_of_already_processed_events += 1;
891 }
892
893 if event_pages.page.has_more {
894 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
895 {
896 log::info!(
897 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
898 );
899 break;
900 } else {
901 log::info!("Stripe events: retrieving next page");
902 event_pages = event_pages.next(&stripe_client).await?;
903 }
904 } else {
905 break;
906 }
907 }
908
909 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
910
911 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
912 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
913
914 for event in unprocessed_events {
915 let event_id = event.id.clone();
916 let processed_event_params = CreateProcessedStripeEventParams {
917 stripe_event_id: event.id.to_string(),
918 stripe_event_type: event_type_to_string(event.type_),
919 stripe_event_created_timestamp: event.created,
920 };
921
922 // If the event has happened too far in the past, we don't want to
923 // process it and risk overwriting other more-recent updates.
924 //
925 // 1 day was chosen arbitrarily. This could be made longer or shorter.
926 let one_day = Duration::from_secs(24 * 60 * 60);
927 let a_day_ago = Utc::now() - one_day;
928 if a_day_ago.timestamp() > event.created {
929 log::info!(
930 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
931 event_id
932 );
933 app.db
934 .create_processed_stripe_event(&processed_event_params)
935 .await?;
936
937 continue;
938 }
939
940 let process_result = match event.type_ {
941 EventType::CustomerCreated | EventType::CustomerUpdated => {
942 handle_customer_event(app, stripe_client, event).await
943 }
944 EventType::CustomerSubscriptionCreated
945 | EventType::CustomerSubscriptionUpdated
946 | EventType::CustomerSubscriptionPaused
947 | EventType::CustomerSubscriptionResumed
948 | EventType::CustomerSubscriptionDeleted => {
949 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
950 }
951 _ => Ok(()),
952 };
953
954 if let Some(()) = process_result
955 .with_context(|| format!("failed to process event {event_id} successfully"))
956 .log_err()
957 {
958 app.db
959 .create_processed_stripe_event(&processed_event_params)
960 .await?;
961 }
962 }
963
964 Ok(())
965}
966
967async fn handle_customer_event(
968 app: &Arc<AppState>,
969 _stripe_client: &stripe::Client,
970 event: stripe::Event,
971) -> anyhow::Result<()> {
972 let EventObject::Customer(customer) = event.data.object else {
973 bail!("unexpected event payload for {}", event.id);
974 };
975
976 log::info!("handling Stripe {} event: {}", event.type_, event.id);
977
978 let Some(email) = customer.email else {
979 log::info!("Stripe customer has no email: skipping");
980 return Ok(());
981 };
982
983 let Some(user) = app.db.get_user_by_email(&email).await? else {
984 log::info!("no user found for email: skipping");
985 return Ok(());
986 };
987
988 if let Some(existing_customer) = app
989 .db
990 .get_billing_customer_by_stripe_customer_id(&customer.id)
991 .await?
992 {
993 app.db
994 .update_billing_customer(
995 existing_customer.id,
996 &UpdateBillingCustomerParams {
997 // For now we just leave the information as-is, as it is not
998 // likely to change.
999 ..Default::default()
1000 },
1001 )
1002 .await?;
1003 } else {
1004 app.db
1005 .create_billing_customer(&CreateBillingCustomerParams {
1006 user_id: user.id,
1007 stripe_customer_id: customer.id.to_string(),
1008 })
1009 .await?;
1010 }
1011
1012 Ok(())
1013}
1014
1015async fn sync_subscription(
1016 app: &Arc<AppState>,
1017 stripe_client: &stripe::Client,
1018 subscription: stripe::Subscription,
1019) -> anyhow::Result<billing_customer::Model> {
1020 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1021 stripe_billing
1022 .determine_subscription_kind(&subscription)
1023 .await
1024 } else {
1025 None
1026 };
1027
1028 let billing_customer =
1029 find_or_create_billing_customer(app, stripe_client, subscription.customer)
1030 .await?
1031 .context("billing customer not found")?;
1032
1033 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1034 if subscription.status == SubscriptionStatus::Trialing {
1035 let current_period_start =
1036 DateTime::from_timestamp(subscription.current_period_start, 0)
1037 .context("No trial subscription period start")?;
1038
1039 app.db
1040 .update_billing_customer(
1041 billing_customer.id,
1042 &UpdateBillingCustomerParams {
1043 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1044 ..Default::default()
1045 },
1046 )
1047 .await?;
1048 }
1049 }
1050
1051 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1052 && subscription
1053 .cancellation_details
1054 .as_ref()
1055 .and_then(|details| details.reason)
1056 .map_or(false, |reason| {
1057 reason == CancellationDetailsReason::PaymentFailed
1058 });
1059
1060 if was_canceled_due_to_payment_failure {
1061 app.db
1062 .update_billing_customer(
1063 billing_customer.id,
1064 &UpdateBillingCustomerParams {
1065 has_overdue_invoices: ActiveValue::set(true),
1066 ..Default::default()
1067 },
1068 )
1069 .await?;
1070 }
1071
1072 if let Some(existing_subscription) = app
1073 .db
1074 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1075 .await?
1076 {
1077 app.db
1078 .update_billing_subscription(
1079 existing_subscription.id,
1080 &UpdateBillingSubscriptionParams {
1081 billing_customer_id: ActiveValue::set(billing_customer.id),
1082 kind: ActiveValue::set(subscription_kind),
1083 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1084 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1085 stripe_cancel_at: ActiveValue::set(
1086 subscription
1087 .cancel_at
1088 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1089 .map(|time| time.naive_utc()),
1090 ),
1091 stripe_cancellation_reason: ActiveValue::set(
1092 subscription
1093 .cancellation_details
1094 .and_then(|details| details.reason)
1095 .map(|reason| reason.into()),
1096 ),
1097 stripe_current_period_start: ActiveValue::set(Some(
1098 subscription.current_period_start,
1099 )),
1100 stripe_current_period_end: ActiveValue::set(Some(
1101 subscription.current_period_end,
1102 )),
1103 },
1104 )
1105 .await?;
1106 } else {
1107 if let Some(existing_subscription) = app
1108 .db
1109 .get_active_billing_subscription(billing_customer.user_id)
1110 .await?
1111 {
1112 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1113 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1114 {
1115 let stripe_subscription_id = existing_subscription
1116 .stripe_subscription_id
1117 .parse::<stripe::SubscriptionId>()
1118 .context("failed to parse Stripe subscription ID from database")?;
1119
1120 Subscription::cancel(
1121 &stripe_client,
1122 &stripe_subscription_id,
1123 stripe::CancelSubscription {
1124 invoice_now: None,
1125 ..Default::default()
1126 },
1127 )
1128 .await?;
1129 } else {
1130 // If the user already has an active billing subscription, ignore the
1131 // event and return an `Ok` to signal that it was processed
1132 // successfully.
1133 //
1134 // There is the possibility that this could cause us to not create a
1135 // subscription in the following scenario:
1136 //
1137 // 1. User has an active subscription A
1138 // 2. User cancels subscription A
1139 // 3. User creates a new subscription B
1140 // 4. We process the new subscription B before the cancellation of subscription A
1141 // 5. User ends up with no subscriptions
1142 //
1143 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1144
1145 log::info!(
1146 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1147 user_id = billing_customer.user_id,
1148 subscription_id = subscription.id
1149 );
1150 return Ok(billing_customer);
1151 }
1152 }
1153
1154 app.db
1155 .create_billing_subscription(&CreateBillingSubscriptionParams {
1156 billing_customer_id: billing_customer.id,
1157 kind: subscription_kind,
1158 stripe_subscription_id: subscription.id.to_string(),
1159 stripe_subscription_status: subscription.status.into(),
1160 stripe_cancellation_reason: subscription
1161 .cancellation_details
1162 .and_then(|details| details.reason)
1163 .map(|reason| reason.into()),
1164 stripe_current_period_start: Some(subscription.current_period_start),
1165 stripe_current_period_end: Some(subscription.current_period_end),
1166 })
1167 .await?;
1168 }
1169
1170 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1171 if subscription.status == SubscriptionStatus::Canceled
1172 || subscription.status == SubscriptionStatus::Paused
1173 {
1174 let already_has_active_billing_subscription = app
1175 .db
1176 .has_active_billing_subscription(billing_customer.user_id)
1177 .await?;
1178 if !already_has_active_billing_subscription {
1179 let stripe_customer_id = billing_customer
1180 .stripe_customer_id
1181 .parse::<stripe::CustomerId>()
1182 .context("failed to parse Stripe customer ID from database")?;
1183
1184 stripe_billing
1185 .subscribe_to_zed_free(stripe_customer_id)
1186 .await?;
1187 }
1188 }
1189 }
1190
1191 Ok(billing_customer)
1192}
1193
1194async fn handle_customer_subscription_event(
1195 app: &Arc<AppState>,
1196 rpc_server: &Arc<Server>,
1197 stripe_client: &stripe::Client,
1198 event: stripe::Event,
1199) -> anyhow::Result<()> {
1200 let EventObject::Subscription(subscription) = event.data.object else {
1201 bail!("unexpected event payload for {}", event.id);
1202 };
1203
1204 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1205
1206 let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1207
1208 // When the user's subscription changes, push down any changes to their plan.
1209 rpc_server
1210 .update_plan_for_user(billing_customer.user_id)
1211 .await
1212 .trace_err();
1213
1214 // When the user's subscription changes, we want to refresh their LLM tokens
1215 // to either grant/revoke access.
1216 rpc_server
1217 .refresh_llm_tokens_for_user(billing_customer.user_id)
1218 .await;
1219
1220 Ok(())
1221}
1222
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID of the user whose usage is requested; resolved to a Zed
    /// user via `get_user_by_github_user_id`.
    github_user_id: i32,
}
1227
/// Usage of a single quota (model requests or edit predictions) in the
/// current billing period.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Number of units consumed so far this period.
    pub used: i32,
    /// The plan's limit for this quota; `None` when the plan is unlimited.
    pub limit: Option<i32>,
    /// Units still available (`limit - used`, clamped at zero when usage is
    /// recorded); `None` when the plan is unlimited.
    pub remaining: Option<i32>,
}
1234
/// Request count for one model/completion-mode pair, taken from a
/// subscription usage meter.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the requests were made against.
    pub model: String,
    /// Completion mode the requests were made in.
    pub mode: CompletionMode,
    /// Number of requests recorded on the usage meter.
    pub requests: i32,
}
1241
/// Usage figures for the user's current billing period.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Aggregate model-request usage against the plan's request limit.
    pub model_requests: UsageCounts,
    /// Per-model breakdown of model requests.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit-prediction usage against the plan's edit-prediction limit.
    pub edit_predictions: UsageCounts,
}
1248
/// Response body of `get_current_usage`.
///
/// The `Default` value (empty `plan`, no `current_usage`) is returned when the
/// user has no active billing subscription or no known billing period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// Name of the user's plan, as produced by `Plan::as_str`.
    pub plan: String,
    /// Usage for the current billing period, when one could be determined.
    pub current_usage: Option<CurrentUsage>,
}
1254
1255async fn get_current_usage(
1256 Extension(app): Extension<Arc<AppState>>,
1257 Query(params): Query<GetCurrentUsageParams>,
1258) -> Result<Json<GetCurrentUsageResponse>> {
1259 let user = app
1260 .db
1261 .get_user_by_github_user_id(params.github_user_id)
1262 .await?
1263 .context("user not found")?;
1264
1265 let feature_flags = app.db.get_user_flags(user.id).await?;
1266 let has_extended_trial = feature_flags
1267 .iter()
1268 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1269
1270 let Some(llm_db) = app.llm_db.clone() else {
1271 return Err(Error::http(
1272 StatusCode::NOT_IMPLEMENTED,
1273 "LLM database not available".into(),
1274 ));
1275 };
1276
1277 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1278 return Ok(Json(GetCurrentUsageResponse::default()));
1279 };
1280
1281 let subscription_period = maybe!({
1282 let period_start_at = subscription.current_period_start_at()?;
1283 let period_end_at = subscription.current_period_end_at()?;
1284
1285 Some((period_start_at, period_end_at))
1286 });
1287
1288 let Some((period_start_at, period_end_at)) = subscription_period else {
1289 return Ok(Json(GetCurrentUsageResponse::default()));
1290 };
1291
1292 let usage = llm_db
1293 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1294 .await?;
1295
1296 let plan = subscription
1297 .kind
1298 .map(Into::into)
1299 .unwrap_or(zed_llm_client::Plan::ZedFree);
1300
1301 let model_requests_limit = match plan.model_requests_limit() {
1302 zed_llm_client::UsageLimit::Limited(limit) => {
1303 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1304 1_000
1305 } else {
1306 limit
1307 };
1308
1309 Some(limit)
1310 }
1311 zed_llm_client::UsageLimit::Unlimited => None,
1312 };
1313
1314 let edit_predictions_limit = match plan.edit_predictions_limit() {
1315 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1316 zed_llm_client::UsageLimit::Unlimited => None,
1317 };
1318
1319 let Some(usage) = usage else {
1320 return Ok(Json(GetCurrentUsageResponse {
1321 plan: plan.as_str().to_string(),
1322 current_usage: Some(CurrentUsage {
1323 model_requests: UsageCounts {
1324 used: 0,
1325 limit: model_requests_limit,
1326 remaining: model_requests_limit,
1327 },
1328 model_request_usage: Vec::new(),
1329 edit_predictions: UsageCounts {
1330 used: 0,
1331 limit: edit_predictions_limit,
1332 remaining: edit_predictions_limit,
1333 },
1334 }),
1335 }));
1336 };
1337
1338 let subscription_usage_meters = llm_db
1339 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1340 .await?;
1341
1342 let model_request_usage = subscription_usage_meters
1343 .into_iter()
1344 .filter_map(|(usage_meter, _usage)| {
1345 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1346
1347 Some(ModelRequestUsage {
1348 model: model.name.clone(),
1349 mode: usage_meter.mode,
1350 requests: usage_meter.requests,
1351 })
1352 })
1353 .collect::<Vec<_>>();
1354
1355 Ok(Json(GetCurrentUsageResponse {
1356 plan: plan.as_str().to_string(),
1357 current_usage: Some(CurrentUsage {
1358 model_requests: UsageCounts {
1359 used: usage.model_requests,
1360 limit: model_requests_limit,
1361 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1362 },
1363 model_request_usage,
1364 edit_predictions: UsageCounts {
1365 used: usage.edit_predictions,
1366 limit: edit_predictions_limit,
1367 remaining: edit_predictions_limit
1368 .map(|limit| (limit - usage.edit_predictions).max(0)),
1369 },
1370 }),
1371 }))
1372}
1373
1374impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1375 fn from(value: SubscriptionStatus) -> Self {
1376 match value {
1377 SubscriptionStatus::Incomplete => Self::Incomplete,
1378 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1379 SubscriptionStatus::Trialing => Self::Trialing,
1380 SubscriptionStatus::Active => Self::Active,
1381 SubscriptionStatus::PastDue => Self::PastDue,
1382 SubscriptionStatus::Canceled => Self::Canceled,
1383 SubscriptionStatus::Unpaid => Self::Unpaid,
1384 SubscriptionStatus::Paused => Self::Paused,
1385 }
1386 }
1387}
1388
1389impl From<CancellationDetailsReason> for StripeCancellationReason {
1390 fn from(value: CancellationDetailsReason) -> Self {
1391 match value {
1392 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1393 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1394 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1395 }
1396 }
1397}
1398
1399/// Finds or creates a billing customer using the provided customer.
1400pub async fn find_or_create_billing_customer(
1401 app: &Arc<AppState>,
1402 stripe_client: &stripe::Client,
1403 customer_or_id: Expandable<Customer>,
1404) -> anyhow::Result<Option<billing_customer::Model>> {
1405 let customer_id = match &customer_or_id {
1406 Expandable::Id(id) => id,
1407 Expandable::Object(customer) => customer.id.as_ref(),
1408 };
1409
1410 // If we already have a billing customer record associated with the Stripe customer,
1411 // there's nothing more we need to do.
1412 if let Some(billing_customer) = app
1413 .db
1414 .get_billing_customer_by_stripe_customer_id(customer_id)
1415 .await?
1416 {
1417 return Ok(Some(billing_customer));
1418 }
1419
1420 // If all we have is a customer ID, resolve it to a full customer record by
1421 // hitting the Stripe API.
1422 let customer = match customer_or_id {
1423 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1424 Expandable::Object(customer) => *customer,
1425 };
1426
1427 let Some(email) = customer.email else {
1428 return Ok(None);
1429 };
1430
1431 let Some(user) = app.db.get_user_by_email(&email).await? else {
1432 return Ok(None);
1433 };
1434
1435 let billing_customer = app
1436 .db
1437 .create_billing_customer(&CreateBillingCustomerParams {
1438 user_id: user.id,
1439 stripe_customer_id: customer.id.to_string(),
1440 })
1441 .await?;
1442
1443 Ok(Some(billing_customer))
1444}
1445
1446const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1447
1448pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1449 let Some(stripe_billing) = app.stripe_billing.clone() else {
1450 log::warn!("failed to retrieve Stripe billing object");
1451 return;
1452 };
1453 let Some(llm_db) = app.llm_db.clone() else {
1454 log::warn!("failed to retrieve LLM database");
1455 return;
1456 };
1457
1458 let executor = app.executor.clone();
1459 executor.spawn_detached({
1460 let executor = executor.clone();
1461 async move {
1462 loop {
1463 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1464 .await
1465 .context("failed to sync LLM request usage to Stripe")
1466 .trace_err();
1467 executor
1468 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1469 .await;
1470 }
1471 }
1472 });
1473}
1474
1475async fn sync_model_request_usage_with_stripe(
1476 app: &Arc<AppState>,
1477 llm_db: &Arc<LlmDatabase>,
1478 stripe_billing: &Arc<StripeBilling>,
1479) -> anyhow::Result<()> {
1480 let staff_users = app.db.get_staff_users().await?;
1481 let staff_user_ids = staff_users
1482 .iter()
1483 .map(|user| user.id)
1484 .collect::<HashSet<UserId>>();
1485
1486 let usage_meters = llm_db
1487 .get_current_subscription_usage_meters(Utc::now())
1488 .await?;
1489 let usage_meters = usage_meters
1490 .into_iter()
1491 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1492 .collect::<Vec<_>>();
1493 let user_ids = usage_meters
1494 .iter()
1495 .map(|(_, usage)| usage.user_id)
1496 .collect::<HashSet<UserId>>();
1497 let billing_subscriptions = app
1498 .db
1499 .get_active_zed_pro_billing_subscriptions(user_ids)
1500 .await?;
1501
1502 let claude_3_5_sonnet = stripe_billing
1503 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1504 .await?;
1505 let claude_3_7_sonnet = stripe_billing
1506 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1507 .await?;
1508 let claude_3_7_sonnet_max = stripe_billing
1509 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1510 .await?;
1511
1512 for (usage_meter, usage) in usage_meters {
1513 maybe!(async {
1514 let Some((billing_customer, billing_subscription)) =
1515 billing_subscriptions.get(&usage.user_id)
1516 else {
1517 bail!(
1518 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1519 usage.user_id
1520 );
1521 };
1522
1523 let stripe_customer_id = billing_customer
1524 .stripe_customer_id
1525 .parse::<stripe::CustomerId>()
1526 .context("failed to parse Stripe customer ID from database")?;
1527 let stripe_subscription_id = billing_subscription
1528 .stripe_subscription_id
1529 .parse::<stripe::SubscriptionId>()
1530 .context("failed to parse Stripe subscription ID from database")?;
1531
1532 let model = llm_db.model_by_id(usage_meter.model_id)?;
1533
1534 let (price, meter_event_name) = match model.name.as_str() {
1535 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1536 "claude-3-7-sonnet" => match usage_meter.mode {
1537 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1538 CompletionMode::Max => {
1539 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1540 }
1541 },
1542 model_name => {
1543 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1544 }
1545 };
1546
1547 stripe_billing
1548 .subscribe_to_price(&stripe_subscription_id, price)
1549 .await?;
1550 stripe_billing
1551 .bill_model_request_usage(
1552 &stripe_customer_id,
1553 meter_event_name,
1554 usage_meter.requests,
1555 )
1556 .await?;
1557
1558 Ok(())
1559 })
1560 .await
1561 .log_err();
1562 }
1563
1564 Ok(())
1565}