1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{
31 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
32};
33use crate::rpc::{ResultExt as _, Server};
34use crate::{AppState, Cents, Error, Result};
35use crate::{db::UserId, llm::db::LlmDatabase};
36use crate::{
37 db::{
38 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
39 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
40 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
41 },
42 stripe_billing::StripeBilling,
43};
44
45pub fn router() -> Router {
46 Router::new()
47 .route(
48 "/billing/preferences",
49 get(get_billing_preferences).put(update_billing_preferences),
50 )
51 .route(
52 "/billing/subscriptions",
53 get(list_billing_subscriptions).post(create_billing_subscription),
54 )
55 .route(
56 "/billing/subscriptions/manage",
57 post(manage_billing_subscription),
58 )
59 .route(
60 "/billing/subscriptions/migrate",
61 post(migrate_to_new_billing),
62 )
63 .route("/billing/monthly_spend", get(get_monthly_spend))
64 .route("/billing/usage", get(get_current_usage))
65}
66
67#[derive(Debug, Deserialize)]
68struct GetBillingPreferencesParams {
69 github_user_id: i32,
70}
71
72#[derive(Debug, Serialize)]
73struct BillingPreferencesResponse {
74 trial_started_at: Option<String>,
75 max_monthly_llm_usage_spending_in_cents: i32,
76 model_request_overages_enabled: bool,
77 model_request_overages_spend_limit_in_cents: i32,
78}
79
80async fn get_billing_preferences(
81 Extension(app): Extension<Arc<AppState>>,
82 Query(params): Query<GetBillingPreferencesParams>,
83) -> Result<Json<BillingPreferencesResponse>> {
84 let user = app
85 .db
86 .get_user_by_github_user_id(params.github_user_id)
87 .await?
88 .ok_or_else(|| anyhow!("user not found"))?;
89
90 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
91 let preferences = app.db.get_billing_preferences(user.id).await?;
92
93 Ok(Json(BillingPreferencesResponse {
94 trial_started_at: billing_customer
95 .and_then(|billing_customer| billing_customer.trial_started_at)
96 .map(|trial_started_at| {
97 trial_started_at
98 .and_utc()
99 .to_rfc3339_opts(SecondsFormat::Millis, true)
100 }),
101 max_monthly_llm_usage_spending_in_cents: preferences
102 .as_ref()
103 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
104 preferences.max_monthly_llm_usage_spending_in_cents
105 }),
106 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
107 preferences.model_request_overages_enabled
108 }),
109 model_request_overages_spend_limit_in_cents: preferences
110 .as_ref()
111 .map_or(0, |preferences| {
112 preferences.model_request_overages_spend_limit_in_cents
113 }),
114 }))
115}
116
117#[derive(Debug, Deserialize)]
118struct UpdateBillingPreferencesBody {
119 github_user_id: i32,
120 #[serde(default)]
121 max_monthly_llm_usage_spending_in_cents: i32,
122 #[serde(default)]
123 model_request_overages_enabled: bool,
124 #[serde(default)]
125 model_request_overages_spend_limit_in_cents: i32,
126}
127
128async fn update_billing_preferences(
129 Extension(app): Extension<Arc<AppState>>,
130 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
131 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
132) -> Result<Json<BillingPreferencesResponse>> {
133 let user = app
134 .db
135 .get_user_by_github_user_id(body.github_user_id)
136 .await?
137 .ok_or_else(|| anyhow!("user not found"))?;
138
139 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
140
141 let max_monthly_llm_usage_spending_in_cents =
142 body.max_monthly_llm_usage_spending_in_cents.max(0);
143 let model_request_overages_spend_limit_in_cents =
144 body.model_request_overages_spend_limit_in_cents.max(0);
145
146 let billing_preferences =
147 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
148 app.db
149 .update_billing_preferences(
150 user.id,
151 &UpdateBillingPreferencesParams {
152 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
153 max_monthly_llm_usage_spending_in_cents,
154 ),
155 model_request_overages_enabled: ActiveValue::set(
156 body.model_request_overages_enabled,
157 ),
158 model_request_overages_spend_limit_in_cents: ActiveValue::set(
159 model_request_overages_spend_limit_in_cents,
160 ),
161 },
162 )
163 .await?
164 } else {
165 app.db
166 .create_billing_preferences(
167 user.id,
168 &crate::db::CreateBillingPreferencesParams {
169 max_monthly_llm_usage_spending_in_cents,
170 model_request_overages_enabled: body.model_request_overages_enabled,
171 model_request_overages_spend_limit_in_cents,
172 },
173 )
174 .await?
175 };
176
177 SnowflakeRow::new(
178 "Billing Preferences Updated",
179 Some(user.metrics_id),
180 user.admin,
181 None,
182 json!({
183 "user_id": user.id,
184 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
185 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
186 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
187 }),
188 )
189 .write(&app.kinesis_client, &app.config.kinesis_stream)
190 .await
191 .log_err();
192
193 rpc_server.refresh_llm_tokens_for_user(user.id).await;
194
195 Ok(Json(BillingPreferencesResponse {
196 trial_started_at: billing_customer
197 .and_then(|billing_customer| billing_customer.trial_started_at)
198 .map(|trial_started_at| {
199 trial_started_at
200 .and_utc()
201 .to_rfc3339_opts(SecondsFormat::Millis, true)
202 }),
203 max_monthly_llm_usage_spending_in_cents: billing_preferences
204 .max_monthly_llm_usage_spending_in_cents,
205 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
206 model_request_overages_spend_limit_in_cents: billing_preferences
207 .model_request_overages_spend_limit_in_cents,
208 }))
209}
210
211#[derive(Debug, Deserialize)]
212struct ListBillingSubscriptionsParams {
213 github_user_id: i32,
214}
215
216#[derive(Debug, Serialize)]
217struct BillingSubscriptionJson {
218 id: BillingSubscriptionId,
219 name: String,
220 status: StripeSubscriptionStatus,
221 trial_end_at: Option<String>,
222 cancel_at: Option<String>,
223 /// Whether this subscription can be canceled.
224 is_cancelable: bool,
225}
226
227#[derive(Debug, Serialize)]
228struct ListBillingSubscriptionsResponse {
229 subscriptions: Vec<BillingSubscriptionJson>,
230}
231
232async fn list_billing_subscriptions(
233 Extension(app): Extension<Arc<AppState>>,
234 Query(params): Query<ListBillingSubscriptionsParams>,
235) -> Result<Json<ListBillingSubscriptionsResponse>> {
236 let user = app
237 .db
238 .get_user_by_github_user_id(params.github_user_id)
239 .await?
240 .ok_or_else(|| anyhow!("user not found"))?;
241
242 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
243
244 Ok(Json(ListBillingSubscriptionsResponse {
245 subscriptions: subscriptions
246 .into_iter()
247 .map(|subscription| BillingSubscriptionJson {
248 id: subscription.id,
249 name: match subscription.kind {
250 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
251 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
252 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
253 None => "Zed LLM Usage".to_string(),
254 },
255 status: subscription.stripe_subscription_status,
256 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
257 maybe!({
258 let end_at = subscription.stripe_current_period_end?;
259 let end_at = DateTime::from_timestamp(end_at, 0)?;
260
261 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
262 })
263 } else {
264 None
265 },
266 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
267 cancel_at
268 .and_utc()
269 .to_rfc3339_opts(SecondsFormat::Millis, true)
270 }),
271 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
272 && subscription.stripe_cancel_at.is_none(),
273 })
274 .collect(),
275 }))
276}
277
278#[derive(Debug, Clone, Copy, Deserialize)]
279#[serde(rename_all = "snake_case")]
280enum ProductCode {
281 ZedPro,
282 ZedProTrial,
283 ZedFree,
284}
285
286#[derive(Debug, Deserialize)]
287struct CreateBillingSubscriptionBody {
288 github_user_id: i32,
289 product: Option<ProductCode>,
290}
291
292#[derive(Debug, Serialize)]
293struct CreateBillingSubscriptionResponse {
294 checkout_session_url: String,
295}
296
297/// Initiates a Stripe Checkout session for creating a billing subscription.
298async fn create_billing_subscription(
299 Extension(app): Extension<Arc<AppState>>,
300 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
301) -> Result<Json<CreateBillingSubscriptionResponse>> {
302 let user = app
303 .db
304 .get_user_by_github_user_id(body.github_user_id)
305 .await?
306 .ok_or_else(|| anyhow!("user not found"))?;
307
308 let Some(stripe_client) = app.stripe_client.clone() else {
309 log::error!("failed to retrieve Stripe client");
310 Err(Error::http(
311 StatusCode::NOT_IMPLEMENTED,
312 "not supported".into(),
313 ))?
314 };
315 let Some(stripe_billing) = app.stripe_billing.clone() else {
316 log::error!("failed to retrieve Stripe billing object");
317 Err(Error::http(
318 StatusCode::NOT_IMPLEMENTED,
319 "not supported".into(),
320 ))?
321 };
322
323 if app.db.has_active_billing_subscription(user.id).await? {
324 return Err(Error::http(
325 StatusCode::CONFLICT,
326 "user already has an active subscription".into(),
327 ));
328 }
329
330 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
331 if let Some(existing_billing_customer) = &existing_billing_customer {
332 if existing_billing_customer.has_overdue_invoices {
333 return Err(Error::http(
334 StatusCode::PAYMENT_REQUIRED,
335 "user has overdue invoices".into(),
336 ));
337 }
338 }
339
340 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
341 CustomerId::from_str(&existing_customer.stripe_customer_id)
342 .context("failed to parse customer ID")?
343 } else {
344 let existing_customer = if let Some(email) = user.email_address.as_deref() {
345 let customers = Customer::list(
346 &stripe_client,
347 &stripe::ListCustomers {
348 email: Some(email),
349 ..Default::default()
350 },
351 )
352 .await?;
353
354 customers.data.first().cloned()
355 } else {
356 None
357 };
358
359 if let Some(existing_customer) = existing_customer {
360 existing_customer.id
361 } else {
362 let customer = Customer::create(
363 &stripe_client,
364 CreateCustomer {
365 email: user.email_address.as_deref(),
366 ..Default::default()
367 },
368 )
369 .await?;
370
371 customer.id
372 }
373 };
374
375 let success_url = format!(
376 "{}/account?checkout_complete=1",
377 app.config.zed_dot_dev_url()
378 );
379
380 let checkout_session_url = match body.product {
381 Some(ProductCode::ZedPro) => {
382 stripe_billing
383 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
384 .await?
385 }
386 Some(ProductCode::ZedProTrial) => {
387 if let Some(existing_billing_customer) = &existing_billing_customer {
388 if existing_billing_customer.trial_started_at.is_some() {
389 return Err(Error::http(
390 StatusCode::FORBIDDEN,
391 "user already used free trial".into(),
392 ));
393 }
394 }
395
396 let feature_flags = app.db.get_user_flags(user.id).await?;
397
398 stripe_billing
399 .checkout_with_zed_pro_trial(
400 customer_id,
401 &user.github_login,
402 feature_flags,
403 &success_url,
404 )
405 .await?
406 }
407 Some(ProductCode::ZedFree) => {
408 stripe_billing
409 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
410 .await?
411 }
412 None => {
413 return Err(Error::http(
414 StatusCode::BAD_REQUEST,
415 "No product selected".into(),
416 ));
417 }
418 };
419
420 Ok(Json(CreateBillingSubscriptionResponse {
421 checkout_session_url,
422 }))
423}
424
425#[derive(Debug, PartialEq, Deserialize)]
426#[serde(rename_all = "snake_case")]
427enum ManageSubscriptionIntent {
428 /// The user intends to manage their subscription.
429 ///
430 /// This will open the Stripe billing portal without putting the user in a specific flow.
431 ManageSubscription,
432 /// The user intends to update their payment method.
433 UpdatePaymentMethod,
434 /// The user intends to upgrade to Zed Pro.
435 UpgradeToPro,
436 /// The user intends to cancel their subscription.
437 Cancel,
438 /// The user intends to stop the cancellation of their subscription.
439 StopCancellation,
440}
441
442#[derive(Debug, Deserialize)]
443struct ManageBillingSubscriptionBody {
444 github_user_id: i32,
445 intent: ManageSubscriptionIntent,
446 /// The ID of the subscription to manage.
447 subscription_id: BillingSubscriptionId,
448}
449
450#[derive(Debug, Serialize)]
451struct ManageBillingSubscriptionResponse {
452 billing_portal_session_url: Option<String>,
453}
454
455/// Initiates a Stripe customer portal session for managing a billing subscription.
456async fn manage_billing_subscription(
457 Extension(app): Extension<Arc<AppState>>,
458 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
459) -> Result<Json<ManageBillingSubscriptionResponse>> {
460 let user = app
461 .db
462 .get_user_by_github_user_id(body.github_user_id)
463 .await?
464 .ok_or_else(|| anyhow!("user not found"))?;
465
466 let Some(stripe_client) = app.stripe_client.clone() else {
467 log::error!("failed to retrieve Stripe client");
468 Err(Error::http(
469 StatusCode::NOT_IMPLEMENTED,
470 "not supported".into(),
471 ))?
472 };
473
474 let Some(stripe_billing) = app.stripe_billing.clone() else {
475 log::error!("failed to retrieve Stripe billing object");
476 Err(Error::http(
477 StatusCode::NOT_IMPLEMENTED,
478 "not supported".into(),
479 ))?
480 };
481
482 let customer = app
483 .db
484 .get_billing_customer_by_user_id(user.id)
485 .await?
486 .ok_or_else(|| anyhow!("billing customer not found"))?;
487 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
488 .context("failed to parse customer ID")?;
489
490 let subscription = app
491 .db
492 .get_billing_subscription_by_id(body.subscription_id)
493 .await?
494 .ok_or_else(|| anyhow!("subscription not found"))?;
495 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
496 .context("failed to parse subscription ID")?;
497
498 if body.intent == ManageSubscriptionIntent::StopCancellation {
499 let updated_stripe_subscription = Subscription::update(
500 &stripe_client,
501 &subscription_id,
502 stripe::UpdateSubscription {
503 cancel_at_period_end: Some(false),
504 ..Default::default()
505 },
506 )
507 .await?;
508
509 app.db
510 .update_billing_subscription(
511 subscription.id,
512 &UpdateBillingSubscriptionParams {
513 stripe_cancel_at: ActiveValue::set(
514 updated_stripe_subscription
515 .cancel_at
516 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
517 .map(|time| time.naive_utc()),
518 ),
519 ..Default::default()
520 },
521 )
522 .await?;
523
524 return Ok(Json(ManageBillingSubscriptionResponse {
525 billing_portal_session_url: None,
526 }));
527 }
528
529 let flow = match body.intent {
530 ManageSubscriptionIntent::ManageSubscription => None,
531 ManageSubscriptionIntent::UpgradeToPro => {
532 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
533 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
534
535 let stripe_subscription =
536 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
537
538 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
539 && stripe_subscription.items.data.iter().any(|item| {
540 item.price
541 .as_ref()
542 .map_or(false, |price| price.id == zed_pro_price_id)
543 });
544 if is_on_zed_pro_trial {
545 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
546 Subscription::update(
547 &stripe_client,
548 &stripe_subscription.id,
549 stripe::UpdateSubscription {
550 trial_end: Some(stripe::Scheduled::now()),
551 ..Default::default()
552 },
553 )
554 .await?;
555
556 return Ok(Json(ManageBillingSubscriptionResponse {
557 billing_portal_session_url: None,
558 }));
559 }
560
561 let subscription_item_to_update = stripe_subscription
562 .items
563 .data
564 .iter()
565 .find_map(|item| {
566 let price = item.price.as_ref()?;
567
568 if price.id == zed_free_price_id {
569 Some(item.id.clone())
570 } else {
571 None
572 }
573 })
574 .ok_or_else(|| anyhow!("No subscription item to update"))?;
575
576 Some(CreateBillingPortalSessionFlowData {
577 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
578 subscription_update_confirm: Some(
579 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
580 subscription: subscription.stripe_subscription_id,
581 items: vec![
582 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
583 id: subscription_item_to_update.to_string(),
584 price: Some(zed_pro_price_id.to_string()),
585 quantity: Some(1),
586 },
587 ],
588 discounts: None,
589 },
590 ),
591 ..Default::default()
592 })
593 }
594 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
595 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
596 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
597 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
598 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
599 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
600 }),
601 ..Default::default()
602 }),
603 ..Default::default()
604 }),
605 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
606 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
607 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
608 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
609 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
610 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
611 }),
612 ..Default::default()
613 }),
614 subscription_cancel: Some(
615 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
616 subscription: subscription.stripe_subscription_id,
617 retention: None,
618 },
619 ),
620 ..Default::default()
621 }),
622 ManageSubscriptionIntent::StopCancellation => unreachable!(),
623 };
624
625 let mut params = CreateBillingPortalSession::new(customer_id);
626 params.flow_data = flow;
627 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
628 params.return_url = Some(&return_url);
629
630 let session = BillingPortalSession::create(&stripe_client, params).await?;
631
632 Ok(Json(ManageBillingSubscriptionResponse {
633 billing_portal_session_url: Some(session.url),
634 }))
635}
636
637#[derive(Debug, Deserialize)]
638struct MigrateToNewBillingBody {
639 github_user_id: i32,
640}
641
642#[derive(Debug, Serialize)]
643struct MigrateToNewBillingResponse {
644 /// The ID of the subscription that was canceled.
645 canceled_subscription_id: Option<String>,
646}
647
648async fn migrate_to_new_billing(
649 Extension(app): Extension<Arc<AppState>>,
650 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
651) -> Result<Json<MigrateToNewBillingResponse>> {
652 let Some(stripe_client) = app.stripe_client.clone() else {
653 log::error!("failed to retrieve Stripe client");
654 Err(Error::http(
655 StatusCode::NOT_IMPLEMENTED,
656 "not supported".into(),
657 ))?
658 };
659
660 let user = app
661 .db
662 .get_user_by_github_user_id(body.github_user_id)
663 .await?
664 .ok_or_else(|| anyhow!("user not found"))?;
665
666 let old_billing_subscriptions_by_user = app
667 .db
668 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
669 .await?;
670
671 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
672 old_billing_subscriptions_by_user.get(&user.id)
673 {
674 let stripe_subscription_id = billing_subscription
675 .stripe_subscription_id
676 .parse::<stripe::SubscriptionId>()
677 .context("failed to parse Stripe subscription ID from database")?;
678
679 Subscription::cancel(
680 &stripe_client,
681 &stripe_subscription_id,
682 stripe::CancelSubscription {
683 invoice_now: Some(true),
684 ..Default::default()
685 },
686 )
687 .await?;
688
689 Some(stripe_subscription_id)
690 } else {
691 None
692 };
693
694 let all_feature_flags = app.db.list_feature_flags().await?;
695 let user_feature_flags = app.db.get_user_flags(user.id).await?;
696
697 for feature_flag in ["new-billing", "assistant2"] {
698 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
699 if already_in_feature_flag {
700 continue;
701 }
702
703 let feature_flag = all_feature_flags
704 .iter()
705 .find(|flag| flag.flag == feature_flag)
706 .context("failed to find feature flag: {feature_flag:?}")?;
707
708 app.db.add_user_flag(user.id, feature_flag.id).await?;
709 }
710
711 Ok(Json(MigrateToNewBillingResponse {
712 canceled_subscription_id: canceled_subscription_id
713 .map(|subscription_id| subscription_id.to_string()),
714 }))
715}
716
/// How long to wait between successive polls of the Stripe events API.
///
/// The interval balances two concerns:
/// 1. It should be short, so changes made in Stripe show up here quickly.
/// 2. It should be long enough that polling doesn't eat into our rate limits.
///
/// For comparison, the Sequin folks poll every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// How many events to request per page from the Stripe events API.
///
/// This is the maximum Stripe allows (100), which minimizes the number of
/// requests we make.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we will read
/// before we stop paging through Stripe events.
///
/// This keeps us from over-fetching events we have already seen and handled.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
743
744/// Polls the Stripe events API periodically to reconcile the records in our
745/// database with the data in Stripe.
746pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
747 let Some(stripe_client) = app.stripe_client.clone() else {
748 log::warn!("failed to retrieve Stripe client");
749 return;
750 };
751
752 let executor = app.executor.clone();
753 executor.spawn_detached({
754 let executor = executor.clone();
755 async move {
756 loop {
757 poll_stripe_events(&app, &rpc_server, &stripe_client)
758 .await
759 .log_err();
760
761 executor.sleep(POLL_EVENTS_INTERVAL).await;
762 }
763 }
764 });
765}
766
767async fn poll_stripe_events(
768 app: &Arc<AppState>,
769 rpc_server: &Arc<Server>,
770 stripe_client: &stripe::Client,
771) -> anyhow::Result<()> {
772 fn event_type_to_string(event_type: EventType) -> String {
773 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
774 // so we need to unquote it.
775 event_type.to_string().trim_matches('"').to_string()
776 }
777
778 let event_types = [
779 EventType::CustomerCreated,
780 EventType::CustomerUpdated,
781 EventType::CustomerSubscriptionCreated,
782 EventType::CustomerSubscriptionUpdated,
783 EventType::CustomerSubscriptionPaused,
784 EventType::CustomerSubscriptionResumed,
785 EventType::CustomerSubscriptionDeleted,
786 ]
787 .into_iter()
788 .map(event_type_to_string)
789 .collect::<Vec<_>>();
790
791 let mut pages_of_already_processed_events = 0;
792 let mut unprocessed_events = Vec::new();
793
794 log::info!(
795 "Stripe events: starting retrieval for {}",
796 event_types.join(", ")
797 );
798 let mut params = ListEvents::new();
799 params.types = Some(event_types.clone());
800 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
801
802 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
803 .await?
804 .paginate(params);
805
806 loop {
807 let processed_event_ids = {
808 let event_ids = event_pages
809 .page
810 .data
811 .iter()
812 .map(|event| event.id.as_str())
813 .collect::<Vec<_>>();
814 app.db
815 .get_processed_stripe_events_by_event_ids(&event_ids)
816 .await?
817 .into_iter()
818 .map(|event| event.stripe_event_id)
819 .collect::<Vec<_>>()
820 };
821
822 let mut processed_events_in_page = 0;
823 let events_in_page = event_pages.page.data.len();
824 for event in &event_pages.page.data {
825 if processed_event_ids.contains(&event.id.to_string()) {
826 processed_events_in_page += 1;
827 log::debug!("Stripe events: already processed '{}', skipping", event.id);
828 } else {
829 unprocessed_events.push(event.clone());
830 }
831 }
832
833 if processed_events_in_page == events_in_page {
834 pages_of_already_processed_events += 1;
835 }
836
837 if event_pages.page.has_more {
838 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
839 {
840 log::info!(
841 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
842 );
843 break;
844 } else {
845 log::info!("Stripe events: retrieving next page");
846 event_pages = event_pages.next(&stripe_client).await?;
847 }
848 } else {
849 break;
850 }
851 }
852
853 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
854
855 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
856 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
857
858 for event in unprocessed_events {
859 let event_id = event.id.clone();
860 let processed_event_params = CreateProcessedStripeEventParams {
861 stripe_event_id: event.id.to_string(),
862 stripe_event_type: event_type_to_string(event.type_),
863 stripe_event_created_timestamp: event.created,
864 };
865
866 // If the event has happened too far in the past, we don't want to
867 // process it and risk overwriting other more-recent updates.
868 //
869 // 1 day was chosen arbitrarily. This could be made longer or shorter.
870 let one_day = Duration::from_secs(24 * 60 * 60);
871 let a_day_ago = Utc::now() - one_day;
872 if a_day_ago.timestamp() > event.created {
873 log::info!(
874 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
875 event_id
876 );
877 app.db
878 .create_processed_stripe_event(&processed_event_params)
879 .await?;
880
881 return Ok(());
882 }
883
884 let process_result = match event.type_ {
885 EventType::CustomerCreated | EventType::CustomerUpdated => {
886 handle_customer_event(app, stripe_client, event).await
887 }
888 EventType::CustomerSubscriptionCreated
889 | EventType::CustomerSubscriptionUpdated
890 | EventType::CustomerSubscriptionPaused
891 | EventType::CustomerSubscriptionResumed
892 | EventType::CustomerSubscriptionDeleted => {
893 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
894 }
895 _ => Ok(()),
896 };
897
898 if let Some(()) = process_result
899 .with_context(|| format!("failed to process event {event_id} successfully"))
900 .log_err()
901 {
902 app.db
903 .create_processed_stripe_event(&processed_event_params)
904 .await?;
905 }
906 }
907
908 Ok(())
909}
910
911async fn handle_customer_event(
912 app: &Arc<AppState>,
913 _stripe_client: &stripe::Client,
914 event: stripe::Event,
915) -> anyhow::Result<()> {
916 let EventObject::Customer(customer) = event.data.object else {
917 bail!("unexpected event payload for {}", event.id);
918 };
919
920 log::info!("handling Stripe {} event: {}", event.type_, event.id);
921
922 let Some(email) = customer.email else {
923 log::info!("Stripe customer has no email: skipping");
924 return Ok(());
925 };
926
927 let Some(user) = app.db.get_user_by_email(&email).await? else {
928 log::info!("no user found for email: skipping");
929 return Ok(());
930 };
931
932 if let Some(existing_customer) = app
933 .db
934 .get_billing_customer_by_stripe_customer_id(&customer.id)
935 .await?
936 {
937 app.db
938 .update_billing_customer(
939 existing_customer.id,
940 &UpdateBillingCustomerParams {
941 // For now we just leave the information as-is, as it is not
942 // likely to change.
943 ..Default::default()
944 },
945 )
946 .await?;
947 } else {
948 app.db
949 .create_billing_customer(&CreateBillingCustomerParams {
950 user_id: user.id,
951 stripe_customer_id: customer.id.to_string(),
952 })
953 .await?;
954 }
955
956 Ok(())
957}
958
959async fn handle_customer_subscription_event(
960 app: &Arc<AppState>,
961 rpc_server: &Arc<Server>,
962 stripe_client: &stripe::Client,
963 event: stripe::Event,
964) -> anyhow::Result<()> {
965 let EventObject::Subscription(subscription) = event.data.object else {
966 bail!("unexpected event payload for {}", event.id);
967 };
968
969 log::info!("handling Stripe {} event: {}", event.type_, event.id);
970
971 let subscription_kind = maybe!(async {
972 let stripe_billing = app.stripe_billing.clone()?;
973
974 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
975 let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
976
977 subscription.items.data.iter().find_map(|item| {
978 let price = item.price.as_ref()?;
979
980 if price.id == zed_pro_price_id {
981 Some(if subscription.status == SubscriptionStatus::Trialing {
982 SubscriptionKind::ZedProTrial
983 } else {
984 SubscriptionKind::ZedPro
985 })
986 } else if price.id == zed_free_price_id {
987 Some(SubscriptionKind::ZedFree)
988 } else {
989 None
990 }
991 })
992 })
993 .await;
994
995 let billing_customer =
996 find_or_create_billing_customer(app, stripe_client, subscription.customer)
997 .await?
998 .ok_or_else(|| anyhow!("billing customer not found"))?;
999
1000 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1001 if subscription.status == SubscriptionStatus::Trialing {
1002 let current_period_start =
1003 DateTime::from_timestamp(subscription.current_period_start, 0)
1004 .ok_or_else(|| anyhow!("No trial subscription period start"))?;
1005
1006 app.db
1007 .update_billing_customer(
1008 billing_customer.id,
1009 &UpdateBillingCustomerParams {
1010 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1011 ..Default::default()
1012 },
1013 )
1014 .await?;
1015 }
1016 }
1017
1018 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1019 && subscription
1020 .cancellation_details
1021 .as_ref()
1022 .and_then(|details| details.reason)
1023 .map_or(false, |reason| {
1024 reason == CancellationDetailsReason::PaymentFailed
1025 });
1026
1027 if was_canceled_due_to_payment_failure {
1028 app.db
1029 .update_billing_customer(
1030 billing_customer.id,
1031 &UpdateBillingCustomerParams {
1032 has_overdue_invoices: ActiveValue::set(true),
1033 ..Default::default()
1034 },
1035 )
1036 .await?;
1037 }
1038
1039 if let Some(existing_subscription) = app
1040 .db
1041 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1042 .await?
1043 {
1044 app.db
1045 .update_billing_subscription(
1046 existing_subscription.id,
1047 &UpdateBillingSubscriptionParams {
1048 billing_customer_id: ActiveValue::set(billing_customer.id),
1049 kind: ActiveValue::set(subscription_kind),
1050 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1051 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1052 stripe_cancel_at: ActiveValue::set(
1053 subscription
1054 .cancel_at
1055 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1056 .map(|time| time.naive_utc()),
1057 ),
1058 stripe_cancellation_reason: ActiveValue::set(
1059 subscription
1060 .cancellation_details
1061 .and_then(|details| details.reason)
1062 .map(|reason| reason.into()),
1063 ),
1064 stripe_current_period_start: ActiveValue::set(Some(
1065 subscription.current_period_start,
1066 )),
1067 stripe_current_period_end: ActiveValue::set(Some(
1068 subscription.current_period_end,
1069 )),
1070 },
1071 )
1072 .await?;
1073 } else {
1074 // If the user already has an active billing subscription, ignore the
1075 // event and return an `Ok` to signal that it was processed
1076 // successfully.
1077 //
1078 // There is the possibility that this could cause us to not create a
1079 // subscription in the following scenario:
1080 //
1081 // 1. User has an active subscription A
1082 // 2. User cancels subscription A
1083 // 3. User creates a new subscription B
1084 // 4. We process the new subscription B before the cancellation of subscription A
1085 // 5. User ends up with no subscriptions
1086 //
1087 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1088 if app
1089 .db
1090 .has_active_billing_subscription(billing_customer.user_id)
1091 .await?
1092 {
1093 log::info!(
1094 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1095 user_id = billing_customer.user_id,
1096 subscription_id = subscription.id
1097 );
1098 return Ok(());
1099 }
1100
1101 app.db
1102 .create_billing_subscription(&CreateBillingSubscriptionParams {
1103 billing_customer_id: billing_customer.id,
1104 kind: subscription_kind,
1105 stripe_subscription_id: subscription.id.to_string(),
1106 stripe_subscription_status: subscription.status.into(),
1107 stripe_cancellation_reason: subscription
1108 .cancellation_details
1109 .and_then(|details| details.reason)
1110 .map(|reason| reason.into()),
1111 stripe_current_period_start: Some(subscription.current_period_start),
1112 stripe_current_period_end: Some(subscription.current_period_end),
1113 })
1114 .await?;
1115 }
1116
1117 // When the user's subscription changes, we want to refresh their LLM tokens
1118 // to either grant/revoke access.
1119 rpc_server
1120 .refresh_llm_tokens_for_user(billing_customer.user_id)
1121 .await;
1122
1123 Ok(())
1124}
1125
1126#[derive(Debug, Deserialize)]
1127struct GetMonthlySpendParams {
1128 github_user_id: i32,
1129}
1130
1131#[derive(Debug, Serialize)]
1132struct GetMonthlySpendResponse {
1133 monthly_free_tier_spend_in_cents: u32,
1134 monthly_free_tier_allowance_in_cents: u32,
1135 monthly_spend_in_cents: u32,
1136}
1137
1138async fn get_monthly_spend(
1139 Extension(app): Extension<Arc<AppState>>,
1140 Query(params): Query<GetMonthlySpendParams>,
1141) -> Result<Json<GetMonthlySpendResponse>> {
1142 let user = app
1143 .db
1144 .get_user_by_github_user_id(params.github_user_id)
1145 .await?
1146 .ok_or_else(|| anyhow!("user not found"))?;
1147
1148 let Some(llm_db) = app.llm_db.clone() else {
1149 return Err(Error::http(
1150 StatusCode::NOT_IMPLEMENTED,
1151 "LLM database not available".into(),
1152 ));
1153 };
1154
1155 let free_tier = user
1156 .custom_llm_monthly_allowance_in_cents
1157 .map(|allowance| Cents(allowance as u32))
1158 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1159
1160 let spending_for_month = llm_db
1161 .get_user_spending_for_month(user.id, Utc::now())
1162 .await?;
1163
1164 let free_tier_spend = Cents::min(spending_for_month, free_tier);
1165 let monthly_spend = spending_for_month.saturating_sub(free_tier);
1166
1167 Ok(Json(GetMonthlySpendResponse {
1168 monthly_free_tier_spend_in_cents: free_tier_spend.0,
1169 monthly_free_tier_allowance_in_cents: free_tier.0,
1170 monthly_spend_in_cents: monthly_spend.0,
1171 }))
1172}
1173
1174#[derive(Debug, Deserialize)]
1175struct GetCurrentUsageParams {
1176 github_user_id: i32,
1177}
1178
1179#[derive(Debug, Serialize)]
1180struct UsageCounts {
1181 pub used: i32,
1182 pub limit: Option<i32>,
1183 pub remaining: Option<i32>,
1184}
1185
1186#[derive(Debug, Serialize)]
1187struct ModelRequestUsage {
1188 pub model: String,
1189 pub mode: CompletionMode,
1190 pub requests: i32,
1191}
1192
1193#[derive(Debug, Serialize)]
1194struct CurrentUsage {
1195 pub model_requests: UsageCounts,
1196 pub model_request_usage: Vec<ModelRequestUsage>,
1197 pub edit_predictions: UsageCounts,
1198}
1199
1200#[derive(Debug, Default, Serialize)]
1201struct GetCurrentUsageResponse {
1202 pub plan: String,
1203 pub current_usage: Option<CurrentUsage>,
1204}
1205
1206async fn get_current_usage(
1207 Extension(app): Extension<Arc<AppState>>,
1208 Query(params): Query<GetCurrentUsageParams>,
1209) -> Result<Json<GetCurrentUsageResponse>> {
1210 let user = app
1211 .db
1212 .get_user_by_github_user_id(params.github_user_id)
1213 .await?
1214 .ok_or_else(|| anyhow!("user not found"))?;
1215
1216 let feature_flags = app.db.get_user_flags(user.id).await?;
1217 let has_extended_trial = feature_flags
1218 .iter()
1219 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1220
1221 let Some(llm_db) = app.llm_db.clone() else {
1222 return Err(Error::http(
1223 StatusCode::NOT_IMPLEMENTED,
1224 "LLM database not available".into(),
1225 ));
1226 };
1227
1228 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1229 return Ok(Json(GetCurrentUsageResponse::default()));
1230 };
1231
1232 let subscription_period = maybe!({
1233 let period_start_at = subscription.current_period_start_at()?;
1234 let period_end_at = subscription.current_period_end_at()?;
1235
1236 Some((period_start_at, period_end_at))
1237 });
1238
1239 let Some((period_start_at, period_end_at)) = subscription_period else {
1240 return Ok(Json(GetCurrentUsageResponse::default()));
1241 };
1242
1243 let usage = llm_db
1244 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1245 .await?;
1246
1247 let plan = usage
1248 .as_ref()
1249 .map(|usage| usage.plan.into())
1250 .unwrap_or_else(|| {
1251 subscription
1252 .kind
1253 .map(Into::into)
1254 .unwrap_or(zed_llm_client::Plan::Free)
1255 });
1256
1257 let model_requests_limit = match plan.model_requests_limit() {
1258 zed_llm_client::UsageLimit::Limited(limit) => {
1259 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1260 1_000
1261 } else {
1262 limit
1263 };
1264
1265 Some(limit)
1266 }
1267 zed_llm_client::UsageLimit::Unlimited => None,
1268 };
1269
1270 let edit_predictions_limit = match plan.edit_predictions_limit() {
1271 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1272 zed_llm_client::UsageLimit::Unlimited => None,
1273 };
1274
1275 let Some(usage) = usage else {
1276 return Ok(Json(GetCurrentUsageResponse {
1277 plan: plan.as_str().to_string(),
1278 current_usage: Some(CurrentUsage {
1279 model_requests: UsageCounts {
1280 used: 0,
1281 limit: model_requests_limit,
1282 remaining: model_requests_limit,
1283 },
1284 model_request_usage: Vec::new(),
1285 edit_predictions: UsageCounts {
1286 used: 0,
1287 limit: edit_predictions_limit,
1288 remaining: edit_predictions_limit,
1289 },
1290 }),
1291 }));
1292 };
1293
1294 let subscription_usage_meters = llm_db
1295 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1296 .await?;
1297
1298 let model_request_usage = subscription_usage_meters
1299 .into_iter()
1300 .filter_map(|(usage_meter, _usage)| {
1301 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1302
1303 Some(ModelRequestUsage {
1304 model: model.name.clone(),
1305 mode: usage_meter.mode,
1306 requests: usage_meter.requests,
1307 })
1308 })
1309 .collect::<Vec<_>>();
1310
1311 Ok(Json(GetCurrentUsageResponse {
1312 plan: plan.as_str().to_string(),
1313 current_usage: Some(CurrentUsage {
1314 model_requests: UsageCounts {
1315 used: usage.model_requests,
1316 limit: model_requests_limit,
1317 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1318 },
1319 model_request_usage,
1320 edit_predictions: UsageCounts {
1321 used: usage.edit_predictions,
1322 limit: edit_predictions_limit,
1323 remaining: edit_predictions_limit
1324 .map(|limit| (limit - usage.edit_predictions).max(0)),
1325 },
1326 }),
1327 }))
1328}
1329
1330impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1331 fn from(value: SubscriptionStatus) -> Self {
1332 match value {
1333 SubscriptionStatus::Incomplete => Self::Incomplete,
1334 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1335 SubscriptionStatus::Trialing => Self::Trialing,
1336 SubscriptionStatus::Active => Self::Active,
1337 SubscriptionStatus::PastDue => Self::PastDue,
1338 SubscriptionStatus::Canceled => Self::Canceled,
1339 SubscriptionStatus::Unpaid => Self::Unpaid,
1340 SubscriptionStatus::Paused => Self::Paused,
1341 }
1342 }
1343}
1344
1345impl From<CancellationDetailsReason> for StripeCancellationReason {
1346 fn from(value: CancellationDetailsReason) -> Self {
1347 match value {
1348 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1349 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1350 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1351 }
1352 }
1353}
1354
1355/// Finds or creates a billing customer using the provided customer.
1356async fn find_or_create_billing_customer(
1357 app: &Arc<AppState>,
1358 stripe_client: &stripe::Client,
1359 customer_or_id: Expandable<Customer>,
1360) -> anyhow::Result<Option<billing_customer::Model>> {
1361 let customer_id = match &customer_or_id {
1362 Expandable::Id(id) => id,
1363 Expandable::Object(customer) => customer.id.as_ref(),
1364 };
1365
1366 // If we already have a billing customer record associated with the Stripe customer,
1367 // there's nothing more we need to do.
1368 if let Some(billing_customer) = app
1369 .db
1370 .get_billing_customer_by_stripe_customer_id(customer_id)
1371 .await?
1372 {
1373 return Ok(Some(billing_customer));
1374 }
1375
1376 // If all we have is a customer ID, resolve it to a full customer record by
1377 // hitting the Stripe API.
1378 let customer = match customer_or_id {
1379 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1380 Expandable::Object(customer) => *customer,
1381 };
1382
1383 let Some(email) = customer.email else {
1384 return Ok(None);
1385 };
1386
1387 let Some(user) = app.db.get_user_by_email(&email).await? else {
1388 return Ok(None);
1389 };
1390
1391 let billing_customer = app
1392 .db
1393 .create_billing_customer(&CreateBillingCustomerParams {
1394 user_id: user.id,
1395 stripe_customer_id: customer.id.to_string(),
1396 })
1397 .await?;
1398
1399 Ok(Some(billing_customer))
1400}
1401
1402const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1403
1404pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1405 let Some(stripe_billing) = app.stripe_billing.clone() else {
1406 log::warn!("failed to retrieve Stripe billing object");
1407 return;
1408 };
1409 let Some(llm_db) = app.llm_db.clone() else {
1410 log::warn!("failed to retrieve LLM database");
1411 return;
1412 };
1413
1414 let executor = app.executor.clone();
1415 executor.spawn_detached({
1416 let executor = executor.clone();
1417 async move {
1418 loop {
1419 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1420 .await
1421 .context("failed to sync LLM request usage to Stripe")
1422 .trace_err();
1423 executor
1424 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1425 .await;
1426 }
1427 }
1428 });
1429}
1430
1431async fn sync_model_request_usage_with_stripe(
1432 app: &Arc<AppState>,
1433 llm_db: &Arc<LlmDatabase>,
1434 stripe_billing: &Arc<StripeBilling>,
1435) -> anyhow::Result<()> {
1436 let staff_users = app.db.get_staff_users().await?;
1437 let staff_user_ids = staff_users
1438 .iter()
1439 .map(|user| user.id)
1440 .collect::<HashSet<UserId>>();
1441
1442 let usage_meters = llm_db
1443 .get_current_subscription_usage_meters(Utc::now())
1444 .await?;
1445 let usage_meters = usage_meters
1446 .into_iter()
1447 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1448 .collect::<Vec<_>>();
1449 let user_ids = usage_meters
1450 .iter()
1451 .map(|(_, usage)| usage.user_id)
1452 .collect::<HashSet<UserId>>();
1453 let billing_subscriptions = app
1454 .db
1455 .get_active_zed_pro_billing_subscriptions(user_ids)
1456 .await?;
1457
1458 let claude_3_5_sonnet = stripe_billing
1459 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1460 .await?;
1461 let claude_3_7_sonnet = stripe_billing
1462 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1463 .await?;
1464 let claude_3_7_sonnet_max = stripe_billing
1465 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1466 .await?;
1467
1468 for (usage_meter, usage) in usage_meters {
1469 maybe!(async {
1470 let Some((billing_customer, billing_subscription)) =
1471 billing_subscriptions.get(&usage.user_id)
1472 else {
1473 bail!(
1474 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1475 usage.user_id
1476 );
1477 };
1478
1479 let stripe_customer_id = billing_customer
1480 .stripe_customer_id
1481 .parse::<stripe::CustomerId>()
1482 .context("failed to parse Stripe customer ID from database")?;
1483 let stripe_subscription_id = billing_subscription
1484 .stripe_subscription_id
1485 .parse::<stripe::SubscriptionId>()
1486 .context("failed to parse Stripe subscription ID from database")?;
1487
1488 let model = llm_db.model_by_id(usage_meter.model_id)?;
1489
1490 let (price, meter_event_name) = match model.name.as_str() {
1491 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1492 "claude-3-7-sonnet" => match usage_meter.mode {
1493 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1494 CompletionMode::Max => {
1495 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1496 }
1497 },
1498 model_name => {
1499 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1500 }
1501 };
1502
1503 stripe_billing
1504 .subscribe_to_price(&stripe_subscription_id, price)
1505 .await?;
1506 stripe_billing
1507 .bill_model_request_usage(
1508 &stripe_customer_id,
1509 meter_event_name,
1510 usage_meter.requests,
1511 )
1512 .await?;
1513
1514 Ok(())
1515 })
1516 .await
1517 .log_err();
1518 }
1519
1520 Ok(())
1521}