1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{
31 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
32};
33use crate::rpc::{ResultExt as _, Server};
34use crate::{AppState, Cents, Error, Result};
35use crate::{db::UserId, llm::db::LlmDatabase};
36use crate::{
37 db::{
38 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
39 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
40 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
41 },
42 stripe_billing::StripeBilling,
43};
44
45pub fn router() -> Router {
46 Router::new()
47 .route(
48 "/billing/preferences",
49 get(get_billing_preferences).put(update_billing_preferences),
50 )
51 .route(
52 "/billing/subscriptions",
53 get(list_billing_subscriptions).post(create_billing_subscription),
54 )
55 .route(
56 "/billing/subscriptions/manage",
57 post(manage_billing_subscription),
58 )
59 .route(
60 "/billing/subscriptions/migrate",
61 post(migrate_to_new_billing),
62 )
63 .route("/billing/monthly_spend", get(get_monthly_spend))
64 .route("/billing/usage", get(get_current_usage))
65}
66
67#[derive(Debug, Deserialize)]
68struct GetBillingPreferencesParams {
69 github_user_id: i32,
70}
71
72#[derive(Debug, Serialize)]
73struct BillingPreferencesResponse {
74 trial_started_at: Option<String>,
75 max_monthly_llm_usage_spending_in_cents: i32,
76 model_request_overages_enabled: bool,
77 model_request_overages_spend_limit_in_cents: i32,
78}
79
80async fn get_billing_preferences(
81 Extension(app): Extension<Arc<AppState>>,
82 Query(params): Query<GetBillingPreferencesParams>,
83) -> Result<Json<BillingPreferencesResponse>> {
84 let user = app
85 .db
86 .get_user_by_github_user_id(params.github_user_id)
87 .await?
88 .ok_or_else(|| anyhow!("user not found"))?;
89
90 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
91 let preferences = app.db.get_billing_preferences(user.id).await?;
92
93 Ok(Json(BillingPreferencesResponse {
94 trial_started_at: billing_customer
95 .and_then(|billing_customer| billing_customer.trial_started_at)
96 .map(|trial_started_at| {
97 trial_started_at
98 .and_utc()
99 .to_rfc3339_opts(SecondsFormat::Millis, true)
100 }),
101 max_monthly_llm_usage_spending_in_cents: preferences
102 .as_ref()
103 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
104 preferences.max_monthly_llm_usage_spending_in_cents
105 }),
106 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
107 preferences.model_request_overages_enabled
108 }),
109 model_request_overages_spend_limit_in_cents: preferences
110 .as_ref()
111 .map_or(0, |preferences| {
112 preferences.model_request_overages_spend_limit_in_cents
113 }),
114 }))
115}
116
117#[derive(Debug, Deserialize)]
118struct UpdateBillingPreferencesBody {
119 github_user_id: i32,
120 #[serde(default)]
121 max_monthly_llm_usage_spending_in_cents: i32,
122 #[serde(default)]
123 model_request_overages_enabled: bool,
124 #[serde(default)]
125 model_request_overages_spend_limit_in_cents: i32,
126}
127
128async fn update_billing_preferences(
129 Extension(app): Extension<Arc<AppState>>,
130 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
131 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
132) -> Result<Json<BillingPreferencesResponse>> {
133 let user = app
134 .db
135 .get_user_by_github_user_id(body.github_user_id)
136 .await?
137 .ok_or_else(|| anyhow!("user not found"))?;
138
139 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
140
141 let max_monthly_llm_usage_spending_in_cents =
142 body.max_monthly_llm_usage_spending_in_cents.max(0);
143 let model_request_overages_spend_limit_in_cents =
144 body.model_request_overages_spend_limit_in_cents.max(0);
145
146 let billing_preferences =
147 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
148 app.db
149 .update_billing_preferences(
150 user.id,
151 &UpdateBillingPreferencesParams {
152 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
153 max_monthly_llm_usage_spending_in_cents,
154 ),
155 model_request_overages_enabled: ActiveValue::set(
156 body.model_request_overages_enabled,
157 ),
158 model_request_overages_spend_limit_in_cents: ActiveValue::set(
159 model_request_overages_spend_limit_in_cents,
160 ),
161 },
162 )
163 .await?
164 } else {
165 app.db
166 .create_billing_preferences(
167 user.id,
168 &crate::db::CreateBillingPreferencesParams {
169 max_monthly_llm_usage_spending_in_cents,
170 model_request_overages_enabled: body.model_request_overages_enabled,
171 model_request_overages_spend_limit_in_cents,
172 },
173 )
174 .await?
175 };
176
177 SnowflakeRow::new(
178 "Billing Preferences Updated",
179 Some(user.metrics_id),
180 user.admin,
181 None,
182 json!({
183 "user_id": user.id,
184 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
185 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
186 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
187 }),
188 )
189 .write(&app.kinesis_client, &app.config.kinesis_stream)
190 .await
191 .log_err();
192
193 rpc_server.refresh_llm_tokens_for_user(user.id).await;
194
195 Ok(Json(BillingPreferencesResponse {
196 trial_started_at: billing_customer
197 .and_then(|billing_customer| billing_customer.trial_started_at)
198 .map(|trial_started_at| {
199 trial_started_at
200 .and_utc()
201 .to_rfc3339_opts(SecondsFormat::Millis, true)
202 }),
203 max_monthly_llm_usage_spending_in_cents: billing_preferences
204 .max_monthly_llm_usage_spending_in_cents,
205 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
206 model_request_overages_spend_limit_in_cents: billing_preferences
207 .model_request_overages_spend_limit_in_cents,
208 }))
209}
210
211#[derive(Debug, Deserialize)]
212struct ListBillingSubscriptionsParams {
213 github_user_id: i32,
214}
215
216#[derive(Debug, Serialize)]
217struct BillingSubscriptionJson {
218 id: BillingSubscriptionId,
219 name: String,
220 status: StripeSubscriptionStatus,
221 trial_end_at: Option<String>,
222 cancel_at: Option<String>,
223 /// Whether this subscription can be canceled.
224 is_cancelable: bool,
225}
226
227#[derive(Debug, Serialize)]
228struct ListBillingSubscriptionsResponse {
229 subscriptions: Vec<BillingSubscriptionJson>,
230}
231
232async fn list_billing_subscriptions(
233 Extension(app): Extension<Arc<AppState>>,
234 Query(params): Query<ListBillingSubscriptionsParams>,
235) -> Result<Json<ListBillingSubscriptionsResponse>> {
236 let user = app
237 .db
238 .get_user_by_github_user_id(params.github_user_id)
239 .await?
240 .ok_or_else(|| anyhow!("user not found"))?;
241
242 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
243
244 Ok(Json(ListBillingSubscriptionsResponse {
245 subscriptions: subscriptions
246 .into_iter()
247 .map(|subscription| BillingSubscriptionJson {
248 id: subscription.id,
249 name: match subscription.kind {
250 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
251 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
252 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
253 None => "Zed LLM Usage".to_string(),
254 },
255 status: subscription.stripe_subscription_status,
256 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
257 maybe!({
258 let end_at = subscription.stripe_current_period_end?;
259 let end_at = DateTime::from_timestamp(end_at, 0)?;
260
261 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
262 })
263 } else {
264 None
265 },
266 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
267 cancel_at
268 .and_utc()
269 .to_rfc3339_opts(SecondsFormat::Millis, true)
270 }),
271 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
272 && subscription.stripe_cancel_at.is_none(),
273 })
274 .collect(),
275 }))
276}
277
278#[derive(Debug, Clone, Copy, Deserialize)]
279#[serde(rename_all = "snake_case")]
280enum ProductCode {
281 ZedPro,
282 ZedProTrial,
283 ZedFree,
284}
285
286#[derive(Debug, Deserialize)]
287struct CreateBillingSubscriptionBody {
288 github_user_id: i32,
289 product: Option<ProductCode>,
290}
291
292#[derive(Debug, Serialize)]
293struct CreateBillingSubscriptionResponse {
294 checkout_session_url: String,
295}
296
297/// Initiates a Stripe Checkout session for creating a billing subscription.
298async fn create_billing_subscription(
299 Extension(app): Extension<Arc<AppState>>,
300 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
301) -> Result<Json<CreateBillingSubscriptionResponse>> {
302 let user = app
303 .db
304 .get_user_by_github_user_id(body.github_user_id)
305 .await?
306 .ok_or_else(|| anyhow!("user not found"))?;
307
308 let Some(stripe_client) = app.stripe_client.clone() else {
309 log::error!("failed to retrieve Stripe client");
310 Err(Error::http(
311 StatusCode::NOT_IMPLEMENTED,
312 "not supported".into(),
313 ))?
314 };
315 let Some(stripe_billing) = app.stripe_billing.clone() else {
316 log::error!("failed to retrieve Stripe billing object");
317 Err(Error::http(
318 StatusCode::NOT_IMPLEMENTED,
319 "not supported".into(),
320 ))?
321 };
322
323 if app.db.has_active_billing_subscription(user.id).await? {
324 return Err(Error::http(
325 StatusCode::CONFLICT,
326 "user already has an active subscription".into(),
327 ));
328 }
329
330 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
331 if let Some(existing_billing_customer) = &existing_billing_customer {
332 if existing_billing_customer.has_overdue_invoices {
333 return Err(Error::http(
334 StatusCode::PAYMENT_REQUIRED,
335 "user has overdue invoices".into(),
336 ));
337 }
338 }
339
340 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
341 CustomerId::from_str(&existing_customer.stripe_customer_id)
342 .context("failed to parse customer ID")?
343 } else {
344 let existing_customer = if let Some(email) = user.email_address.as_deref() {
345 let customers = Customer::list(
346 &stripe_client,
347 &stripe::ListCustomers {
348 email: Some(email),
349 ..Default::default()
350 },
351 )
352 .await?;
353
354 customers.data.first().cloned()
355 } else {
356 None
357 };
358
359 if let Some(existing_customer) = existing_customer {
360 existing_customer.id
361 } else {
362 let customer = Customer::create(
363 &stripe_client,
364 CreateCustomer {
365 email: user.email_address.as_deref(),
366 ..Default::default()
367 },
368 )
369 .await?;
370
371 customer.id
372 }
373 };
374
375 let success_url = format!(
376 "{}/account?checkout_complete=1",
377 app.config.zed_dot_dev_url()
378 );
379
380 let checkout_session_url = match body.product {
381 Some(ProductCode::ZedPro) => {
382 stripe_billing
383 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
384 .await?
385 }
386 Some(ProductCode::ZedProTrial) => {
387 if let Some(existing_billing_customer) = &existing_billing_customer {
388 if existing_billing_customer.trial_started_at.is_some() {
389 return Err(Error::http(
390 StatusCode::FORBIDDEN,
391 "user already used free trial".into(),
392 ));
393 }
394 }
395
396 let feature_flags = app.db.get_user_flags(user.id).await?;
397
398 stripe_billing
399 .checkout_with_zed_pro_trial(
400 customer_id,
401 &user.github_login,
402 feature_flags,
403 &success_url,
404 )
405 .await?
406 }
407 Some(ProductCode::ZedFree) => {
408 stripe_billing
409 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
410 .await?
411 }
412 None => {
413 return Err(Error::http(
414 StatusCode::BAD_REQUEST,
415 "No product selected".into(),
416 ));
417 }
418 };
419
420 Ok(Json(CreateBillingSubscriptionResponse {
421 checkout_session_url,
422 }))
423}
424
425#[derive(Debug, PartialEq, Deserialize)]
426#[serde(rename_all = "snake_case")]
427enum ManageSubscriptionIntent {
428 /// The user intends to manage their subscription.
429 ///
430 /// This will open the Stripe billing portal without putting the user in a specific flow.
431 ManageSubscription,
432 /// The user intends to upgrade to Zed Pro.
433 UpgradeToPro,
434 /// The user intends to cancel their subscription.
435 Cancel,
436 /// The user intends to stop the cancellation of their subscription.
437 StopCancellation,
438}
439
440#[derive(Debug, Deserialize)]
441struct ManageBillingSubscriptionBody {
442 github_user_id: i32,
443 intent: ManageSubscriptionIntent,
444 /// The ID of the subscription to manage.
445 subscription_id: BillingSubscriptionId,
446}
447
448#[derive(Debug, Serialize)]
449struct ManageBillingSubscriptionResponse {
450 billing_portal_session_url: Option<String>,
451}
452
453/// Initiates a Stripe customer portal session for managing a billing subscription.
454async fn manage_billing_subscription(
455 Extension(app): Extension<Arc<AppState>>,
456 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
457) -> Result<Json<ManageBillingSubscriptionResponse>> {
458 let user = app
459 .db
460 .get_user_by_github_user_id(body.github_user_id)
461 .await?
462 .ok_or_else(|| anyhow!("user not found"))?;
463
464 let Some(stripe_client) = app.stripe_client.clone() else {
465 log::error!("failed to retrieve Stripe client");
466 Err(Error::http(
467 StatusCode::NOT_IMPLEMENTED,
468 "not supported".into(),
469 ))?
470 };
471
472 let Some(stripe_billing) = app.stripe_billing.clone() else {
473 log::error!("failed to retrieve Stripe billing object");
474 Err(Error::http(
475 StatusCode::NOT_IMPLEMENTED,
476 "not supported".into(),
477 ))?
478 };
479
480 let customer = app
481 .db
482 .get_billing_customer_by_user_id(user.id)
483 .await?
484 .ok_or_else(|| anyhow!("billing customer not found"))?;
485 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
486 .context("failed to parse customer ID")?;
487
488 let subscription = app
489 .db
490 .get_billing_subscription_by_id(body.subscription_id)
491 .await?
492 .ok_or_else(|| anyhow!("subscription not found"))?;
493 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
494 .context("failed to parse subscription ID")?;
495
496 if body.intent == ManageSubscriptionIntent::StopCancellation {
497 let updated_stripe_subscription = Subscription::update(
498 &stripe_client,
499 &subscription_id,
500 stripe::UpdateSubscription {
501 cancel_at_period_end: Some(false),
502 ..Default::default()
503 },
504 )
505 .await?;
506
507 app.db
508 .update_billing_subscription(
509 subscription.id,
510 &UpdateBillingSubscriptionParams {
511 stripe_cancel_at: ActiveValue::set(
512 updated_stripe_subscription
513 .cancel_at
514 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
515 .map(|time| time.naive_utc()),
516 ),
517 ..Default::default()
518 },
519 )
520 .await?;
521
522 return Ok(Json(ManageBillingSubscriptionResponse {
523 billing_portal_session_url: None,
524 }));
525 }
526
527 let flow = match body.intent {
528 ManageSubscriptionIntent::ManageSubscription => None,
529 ManageSubscriptionIntent::UpgradeToPro => {
530 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
531 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
532
533 let stripe_subscription =
534 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
535
536 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
537 && stripe_subscription.items.data.iter().any(|item| {
538 item.price
539 .as_ref()
540 .map_or(false, |price| price.id == zed_pro_price_id)
541 });
542 if is_on_zed_pro_trial {
543 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
544 Subscription::update(
545 &stripe_client,
546 &stripe_subscription.id,
547 stripe::UpdateSubscription {
548 trial_end: Some(stripe::Scheduled::now()),
549 ..Default::default()
550 },
551 )
552 .await?;
553
554 return Ok(Json(ManageBillingSubscriptionResponse {
555 billing_portal_session_url: None,
556 }));
557 }
558
559 let subscription_item_to_update = stripe_subscription
560 .items
561 .data
562 .iter()
563 .find_map(|item| {
564 let price = item.price.as_ref()?;
565
566 if price.id == zed_free_price_id {
567 Some(item.id.clone())
568 } else {
569 None
570 }
571 })
572 .ok_or_else(|| anyhow!("No subscription item to update"))?;
573
574 Some(CreateBillingPortalSessionFlowData {
575 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
576 subscription_update_confirm: Some(
577 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
578 subscription: subscription.stripe_subscription_id,
579 items: vec![
580 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
581 id: subscription_item_to_update.to_string(),
582 price: Some(zed_pro_price_id.to_string()),
583 quantity: Some(1),
584 },
585 ],
586 discounts: None,
587 },
588 ),
589 ..Default::default()
590 })
591 }
592 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
593 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
594 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
595 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
596 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
597 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
598 }),
599 ..Default::default()
600 }),
601 subscription_cancel: Some(
602 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
603 subscription: subscription.stripe_subscription_id,
604 retention: None,
605 },
606 ),
607 ..Default::default()
608 }),
609 ManageSubscriptionIntent::StopCancellation => unreachable!(),
610 };
611
612 let mut params = CreateBillingPortalSession::new(customer_id);
613 params.flow_data = flow;
614 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
615 params.return_url = Some(&return_url);
616
617 let session = BillingPortalSession::create(&stripe_client, params).await?;
618
619 Ok(Json(ManageBillingSubscriptionResponse {
620 billing_portal_session_url: Some(session.url),
621 }))
622}
623
624#[derive(Debug, Deserialize)]
625struct MigrateToNewBillingBody {
626 github_user_id: i32,
627}
628
629#[derive(Debug, Serialize)]
630struct MigrateToNewBillingResponse {
631 /// The ID of the subscription that was canceled.
632 canceled_subscription_id: Option<String>,
633}
634
635async fn migrate_to_new_billing(
636 Extension(app): Extension<Arc<AppState>>,
637 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
638) -> Result<Json<MigrateToNewBillingResponse>> {
639 let Some(stripe_client) = app.stripe_client.clone() else {
640 log::error!("failed to retrieve Stripe client");
641 Err(Error::http(
642 StatusCode::NOT_IMPLEMENTED,
643 "not supported".into(),
644 ))?
645 };
646
647 let user = app
648 .db
649 .get_user_by_github_user_id(body.github_user_id)
650 .await?
651 .ok_or_else(|| anyhow!("user not found"))?;
652
653 let old_billing_subscriptions_by_user = app
654 .db
655 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
656 .await?;
657
658 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
659 old_billing_subscriptions_by_user.get(&user.id)
660 {
661 let stripe_subscription_id = billing_subscription
662 .stripe_subscription_id
663 .parse::<stripe::SubscriptionId>()
664 .context("failed to parse Stripe subscription ID from database")?;
665
666 Subscription::cancel(
667 &stripe_client,
668 &stripe_subscription_id,
669 stripe::CancelSubscription {
670 invoice_now: Some(true),
671 ..Default::default()
672 },
673 )
674 .await?;
675
676 Some(stripe_subscription_id)
677 } else {
678 None
679 };
680
681 let all_feature_flags = app.db.list_feature_flags().await?;
682 let user_feature_flags = app.db.get_user_flags(user.id).await?;
683
684 for feature_flag in ["new-billing", "assistant2"] {
685 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
686 if already_in_feature_flag {
687 continue;
688 }
689
690 let feature_flag = all_feature_flags
691 .iter()
692 .find(|flag| flag.flag == feature_flag)
693 .context("failed to find feature flag: {feature_flag:?}")?;
694
695 app.db.add_user_flag(user.id, feature_flag.id).await?;
696 }
697
698 Ok(Json(MigrateToNewBillingResponse {
699 canceled_subscription_id: canceled_subscription_id
700 .map(|subscription_id| subscription_id.to_string()),
701 }))
702}
703
/// How long to wait between two consecutive polls of the Stripe events API.
///
/// Picking this value is a trade-off between:
/// 1. Reacting quickly when something changes in Stripe, and
/// 2. Not consuming too much of our Stripe API rate limit.
///
/// For comparison, the Sequin folks report polling every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// Page size requested when listing Stripe events.
///
/// 100 is the largest page Stripe allows, so using it minimizes the number of requests:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages consisting solely of already-processed events we accept before we stop
/// retrieving further pages.
///
/// This keeps us from repeatedly re-fetching events we have already handled.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
730
731/// Polls the Stripe events API periodically to reconcile the records in our
732/// database with the data in Stripe.
733pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
734 let Some(stripe_client) = app.stripe_client.clone() else {
735 log::warn!("failed to retrieve Stripe client");
736 return;
737 };
738
739 let executor = app.executor.clone();
740 executor.spawn_detached({
741 let executor = executor.clone();
742 async move {
743 loop {
744 poll_stripe_events(&app, &rpc_server, &stripe_client)
745 .await
746 .log_err();
747
748 executor.sleep(POLL_EVENTS_INTERVAL).await;
749 }
750 }
751 });
752}
753
754async fn poll_stripe_events(
755 app: &Arc<AppState>,
756 rpc_server: &Arc<Server>,
757 stripe_client: &stripe::Client,
758) -> anyhow::Result<()> {
759 fn event_type_to_string(event_type: EventType) -> String {
760 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
761 // so we need to unquote it.
762 event_type.to_string().trim_matches('"').to_string()
763 }
764
765 let event_types = [
766 EventType::CustomerCreated,
767 EventType::CustomerUpdated,
768 EventType::CustomerSubscriptionCreated,
769 EventType::CustomerSubscriptionUpdated,
770 EventType::CustomerSubscriptionPaused,
771 EventType::CustomerSubscriptionResumed,
772 EventType::CustomerSubscriptionDeleted,
773 ]
774 .into_iter()
775 .map(event_type_to_string)
776 .collect::<Vec<_>>();
777
778 let mut pages_of_already_processed_events = 0;
779 let mut unprocessed_events = Vec::new();
780
781 log::info!(
782 "Stripe events: starting retrieval for {}",
783 event_types.join(", ")
784 );
785 let mut params = ListEvents::new();
786 params.types = Some(event_types.clone());
787 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
788
789 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
790 .await?
791 .paginate(params);
792
793 loop {
794 let processed_event_ids = {
795 let event_ids = event_pages
796 .page
797 .data
798 .iter()
799 .map(|event| event.id.as_str())
800 .collect::<Vec<_>>();
801 app.db
802 .get_processed_stripe_events_by_event_ids(&event_ids)
803 .await?
804 .into_iter()
805 .map(|event| event.stripe_event_id)
806 .collect::<Vec<_>>()
807 };
808
809 let mut processed_events_in_page = 0;
810 let events_in_page = event_pages.page.data.len();
811 for event in &event_pages.page.data {
812 if processed_event_ids.contains(&event.id.to_string()) {
813 processed_events_in_page += 1;
814 log::debug!("Stripe events: already processed '{}', skipping", event.id);
815 } else {
816 unprocessed_events.push(event.clone());
817 }
818 }
819
820 if processed_events_in_page == events_in_page {
821 pages_of_already_processed_events += 1;
822 }
823
824 if event_pages.page.has_more {
825 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
826 {
827 log::info!(
828 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
829 );
830 break;
831 } else {
832 log::info!("Stripe events: retrieving next page");
833 event_pages = event_pages.next(&stripe_client).await?;
834 }
835 } else {
836 break;
837 }
838 }
839
840 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
841
842 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
843 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
844
845 for event in unprocessed_events {
846 let event_id = event.id.clone();
847 let processed_event_params = CreateProcessedStripeEventParams {
848 stripe_event_id: event.id.to_string(),
849 stripe_event_type: event_type_to_string(event.type_),
850 stripe_event_created_timestamp: event.created,
851 };
852
853 // If the event has happened too far in the past, we don't want to
854 // process it and risk overwriting other more-recent updates.
855 //
856 // 1 day was chosen arbitrarily. This could be made longer or shorter.
857 let one_day = Duration::from_secs(24 * 60 * 60);
858 let a_day_ago = Utc::now() - one_day;
859 if a_day_ago.timestamp() > event.created {
860 log::info!(
861 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
862 event_id
863 );
864 app.db
865 .create_processed_stripe_event(&processed_event_params)
866 .await?;
867
868 return Ok(());
869 }
870
871 let process_result = match event.type_ {
872 EventType::CustomerCreated | EventType::CustomerUpdated => {
873 handle_customer_event(app, stripe_client, event).await
874 }
875 EventType::CustomerSubscriptionCreated
876 | EventType::CustomerSubscriptionUpdated
877 | EventType::CustomerSubscriptionPaused
878 | EventType::CustomerSubscriptionResumed
879 | EventType::CustomerSubscriptionDeleted => {
880 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
881 }
882 _ => Ok(()),
883 };
884
885 if let Some(()) = process_result
886 .with_context(|| format!("failed to process event {event_id} successfully"))
887 .log_err()
888 {
889 app.db
890 .create_processed_stripe_event(&processed_event_params)
891 .await?;
892 }
893 }
894
895 Ok(())
896}
897
898async fn handle_customer_event(
899 app: &Arc<AppState>,
900 _stripe_client: &stripe::Client,
901 event: stripe::Event,
902) -> anyhow::Result<()> {
903 let EventObject::Customer(customer) = event.data.object else {
904 bail!("unexpected event payload for {}", event.id);
905 };
906
907 log::info!("handling Stripe {} event: {}", event.type_, event.id);
908
909 let Some(email) = customer.email else {
910 log::info!("Stripe customer has no email: skipping");
911 return Ok(());
912 };
913
914 let Some(user) = app.db.get_user_by_email(&email).await? else {
915 log::info!("no user found for email: skipping");
916 return Ok(());
917 };
918
919 if let Some(existing_customer) = app
920 .db
921 .get_billing_customer_by_stripe_customer_id(&customer.id)
922 .await?
923 {
924 app.db
925 .update_billing_customer(
926 existing_customer.id,
927 &UpdateBillingCustomerParams {
928 // For now we just leave the information as-is, as it is not
929 // likely to change.
930 ..Default::default()
931 },
932 )
933 .await?;
934 } else {
935 app.db
936 .create_billing_customer(&CreateBillingCustomerParams {
937 user_id: user.id,
938 stripe_customer_id: customer.id.to_string(),
939 })
940 .await?;
941 }
942
943 Ok(())
944}
945
946async fn handle_customer_subscription_event(
947 app: &Arc<AppState>,
948 rpc_server: &Arc<Server>,
949 stripe_client: &stripe::Client,
950 event: stripe::Event,
951) -> anyhow::Result<()> {
952 let EventObject::Subscription(subscription) = event.data.object else {
953 bail!("unexpected event payload for {}", event.id);
954 };
955
956 log::info!("handling Stripe {} event: {}", event.type_, event.id);
957
958 let subscription_kind = maybe!(async {
959 let stripe_billing = app.stripe_billing.clone()?;
960
961 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
962 let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
963
964 subscription.items.data.iter().find_map(|item| {
965 let price = item.price.as_ref()?;
966
967 if price.id == zed_pro_price_id {
968 Some(if subscription.status == SubscriptionStatus::Trialing {
969 SubscriptionKind::ZedProTrial
970 } else {
971 SubscriptionKind::ZedPro
972 })
973 } else if price.id == zed_free_price_id {
974 Some(SubscriptionKind::ZedFree)
975 } else {
976 None
977 }
978 })
979 })
980 .await;
981
982 let billing_customer =
983 find_or_create_billing_customer(app, stripe_client, subscription.customer)
984 .await?
985 .ok_or_else(|| anyhow!("billing customer not found"))?;
986
987 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
988 if subscription.status == SubscriptionStatus::Trialing {
989 let current_period_start =
990 DateTime::from_timestamp(subscription.current_period_start, 0)
991 .ok_or_else(|| anyhow!("No trial subscription period start"))?;
992
993 app.db
994 .update_billing_customer(
995 billing_customer.id,
996 &UpdateBillingCustomerParams {
997 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
998 ..Default::default()
999 },
1000 )
1001 .await?;
1002 }
1003 }
1004
1005 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1006 && subscription
1007 .cancellation_details
1008 .as_ref()
1009 .and_then(|details| details.reason)
1010 .map_or(false, |reason| {
1011 reason == CancellationDetailsReason::PaymentFailed
1012 });
1013
1014 if was_canceled_due_to_payment_failure {
1015 app.db
1016 .update_billing_customer(
1017 billing_customer.id,
1018 &UpdateBillingCustomerParams {
1019 has_overdue_invoices: ActiveValue::set(true),
1020 ..Default::default()
1021 },
1022 )
1023 .await?;
1024 }
1025
1026 if let Some(existing_subscription) = app
1027 .db
1028 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1029 .await?
1030 {
1031 app.db
1032 .update_billing_subscription(
1033 existing_subscription.id,
1034 &UpdateBillingSubscriptionParams {
1035 billing_customer_id: ActiveValue::set(billing_customer.id),
1036 kind: ActiveValue::set(subscription_kind),
1037 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1038 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1039 stripe_cancel_at: ActiveValue::set(
1040 subscription
1041 .cancel_at
1042 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1043 .map(|time| time.naive_utc()),
1044 ),
1045 stripe_cancellation_reason: ActiveValue::set(
1046 subscription
1047 .cancellation_details
1048 .and_then(|details| details.reason)
1049 .map(|reason| reason.into()),
1050 ),
1051 stripe_current_period_start: ActiveValue::set(Some(
1052 subscription.current_period_start,
1053 )),
1054 stripe_current_period_end: ActiveValue::set(Some(
1055 subscription.current_period_end,
1056 )),
1057 },
1058 )
1059 .await?;
1060 } else {
1061 // If the user already has an active billing subscription, ignore the
1062 // event and return an `Ok` to signal that it was processed
1063 // successfully.
1064 //
1065 // There is the possibility that this could cause us to not create a
1066 // subscription in the following scenario:
1067 //
1068 // 1. User has an active subscription A
1069 // 2. User cancels subscription A
1070 // 3. User creates a new subscription B
1071 // 4. We process the new subscription B before the cancellation of subscription A
1072 // 5. User ends up with no subscriptions
1073 //
1074 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1075 if app
1076 .db
1077 .has_active_billing_subscription(billing_customer.user_id)
1078 .await?
1079 {
1080 log::info!(
1081 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1082 user_id = billing_customer.user_id,
1083 subscription_id = subscription.id
1084 );
1085 return Ok(());
1086 }
1087
1088 app.db
1089 .create_billing_subscription(&CreateBillingSubscriptionParams {
1090 billing_customer_id: billing_customer.id,
1091 kind: subscription_kind,
1092 stripe_subscription_id: subscription.id.to_string(),
1093 stripe_subscription_status: subscription.status.into(),
1094 stripe_cancellation_reason: subscription
1095 .cancellation_details
1096 .and_then(|details| details.reason)
1097 .map(|reason| reason.into()),
1098 stripe_current_period_start: Some(subscription.current_period_start),
1099 stripe_current_period_end: Some(subscription.current_period_end),
1100 })
1101 .await?;
1102 }
1103
1104 // When the user's subscription changes, we want to refresh their LLM tokens
1105 // to either grant/revoke access.
1106 rpc_server
1107 .refresh_llm_tokens_for_user(billing_customer.user_id)
1108 .await;
1109
1110 Ok(())
1111}
1112
/// Query parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID used to look up the Zed user whose spend is reported.
    github_user_id: i32,
}
1117
/// Response body of `get_monthly_spend`; every amount is in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// Portion of this month's spend covered by the free-tier allowance
    /// (never more than the allowance).
    monthly_free_tier_spend_in_cents: u32,
    /// The user's monthly free-tier allowance: their custom allowance when one
    /// is set, otherwise the default free-tier spending limit.
    monthly_free_tier_allowance_in_cents: u32,
    /// Spend for the month beyond the free-tier allowance (zero while within it).
    monthly_spend_in_cents: u32,
}
1124
1125async fn get_monthly_spend(
1126 Extension(app): Extension<Arc<AppState>>,
1127 Query(params): Query<GetMonthlySpendParams>,
1128) -> Result<Json<GetMonthlySpendResponse>> {
1129 let user = app
1130 .db
1131 .get_user_by_github_user_id(params.github_user_id)
1132 .await?
1133 .ok_or_else(|| anyhow!("user not found"))?;
1134
1135 let Some(llm_db) = app.llm_db.clone() else {
1136 return Err(Error::http(
1137 StatusCode::NOT_IMPLEMENTED,
1138 "LLM database not available".into(),
1139 ));
1140 };
1141
1142 let free_tier = user
1143 .custom_llm_monthly_allowance_in_cents
1144 .map(|allowance| Cents(allowance as u32))
1145 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1146
1147 let spending_for_month = llm_db
1148 .get_user_spending_for_month(user.id, Utc::now())
1149 .await?;
1150
1151 let free_tier_spend = Cents::min(spending_for_month, free_tier);
1152 let monthly_spend = spending_for_month.saturating_sub(free_tier);
1153
1154 Ok(Json(GetMonthlySpendResponse {
1155 monthly_free_tier_spend_in_cents: free_tier_spend.0,
1156 monthly_free_tier_allowance_in_cents: free_tier.0,
1157 monthly_spend_in_cents: monthly_spend.0,
1158 }))
1159}
1160
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID used to look up the Zed user whose usage is reported.
    github_user_id: i32,
}
1165
/// Usage of a single metered resource within the subscription's current
/// billing period.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Number of uses recorded so far in the period.
    pub used: i32,
    /// Maximum allowed by the plan, or `None` when the plan is unlimited.
    pub limit: Option<i32>,
    /// Uses left before reaching `limit`, or `None` when the plan is unlimited.
    pub remaining: Option<i32>,
}
1172
/// Request count for one model in one completion mode, taken from the user's
/// current subscription usage meters.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the requests were made to.
    pub model: String,
    /// Completion mode the requests were made in.
    pub mode: CompletionMode,
    /// Number of requests recorded on the meter.
    pub requests: i32,
}
1179
/// Usage summary for the current billing period.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Aggregate model-request usage against the plan's request limit.
    pub model_requests: UsageCounts,
    /// Per-model, per-mode breakdown of model requests.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Aggregate edit-prediction usage against the plan's prediction limit.
    pub edit_predictions: UsageCounts,
}
1186
/// Response body of `get_current_usage`.
///
/// The `Default` value (empty `plan`, no `current_usage`) is what the handler
/// returns when the user has no active billing subscription, or when that
/// subscription has no known billing period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// String form of the user's plan (from `Plan::as_str`).
    pub plan: String,
    /// Usage for the current billing period, when one could be determined.
    pub current_usage: Option<CurrentUsage>,
}
1192
1193async fn get_current_usage(
1194 Extension(app): Extension<Arc<AppState>>,
1195 Query(params): Query<GetCurrentUsageParams>,
1196) -> Result<Json<GetCurrentUsageResponse>> {
1197 let user = app
1198 .db
1199 .get_user_by_github_user_id(params.github_user_id)
1200 .await?
1201 .ok_or_else(|| anyhow!("user not found"))?;
1202
1203 let feature_flags = app.db.get_user_flags(user.id).await?;
1204 let has_extended_trial = feature_flags
1205 .iter()
1206 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1207
1208 let Some(llm_db) = app.llm_db.clone() else {
1209 return Err(Error::http(
1210 StatusCode::NOT_IMPLEMENTED,
1211 "LLM database not available".into(),
1212 ));
1213 };
1214
1215 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1216 return Ok(Json(GetCurrentUsageResponse::default()));
1217 };
1218
1219 let subscription_period = maybe!({
1220 let period_start_at = subscription.current_period_start_at()?;
1221 let period_end_at = subscription.current_period_end_at()?;
1222
1223 Some((period_start_at, period_end_at))
1224 });
1225
1226 let Some((period_start_at, period_end_at)) = subscription_period else {
1227 return Ok(Json(GetCurrentUsageResponse::default()));
1228 };
1229
1230 let usage = llm_db
1231 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1232 .await?;
1233
1234 let plan = usage
1235 .as_ref()
1236 .map(|usage| usage.plan.into())
1237 .unwrap_or_else(|| {
1238 subscription
1239 .kind
1240 .map(Into::into)
1241 .unwrap_or(zed_llm_client::Plan::Free)
1242 });
1243
1244 let model_requests_limit = match plan.model_requests_limit() {
1245 zed_llm_client::UsageLimit::Limited(limit) => {
1246 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1247 1_000
1248 } else {
1249 limit
1250 };
1251
1252 Some(limit)
1253 }
1254 zed_llm_client::UsageLimit::Unlimited => None,
1255 };
1256
1257 let edit_predictions_limit = match plan.edit_predictions_limit() {
1258 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1259 zed_llm_client::UsageLimit::Unlimited => None,
1260 };
1261
1262 let Some(usage) = usage else {
1263 return Ok(Json(GetCurrentUsageResponse {
1264 plan: plan.as_str().to_string(),
1265 current_usage: Some(CurrentUsage {
1266 model_requests: UsageCounts {
1267 used: 0,
1268 limit: model_requests_limit,
1269 remaining: model_requests_limit,
1270 },
1271 model_request_usage: Vec::new(),
1272 edit_predictions: UsageCounts {
1273 used: 0,
1274 limit: edit_predictions_limit,
1275 remaining: edit_predictions_limit,
1276 },
1277 }),
1278 }));
1279 };
1280
1281 let subscription_usage_meters = llm_db
1282 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1283 .await?;
1284
1285 let model_request_usage = subscription_usage_meters
1286 .into_iter()
1287 .filter_map(|(usage_meter, _usage)| {
1288 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1289
1290 Some(ModelRequestUsage {
1291 model: model.name.clone(),
1292 mode: usage_meter.mode,
1293 requests: usage_meter.requests,
1294 })
1295 })
1296 .collect::<Vec<_>>();
1297
1298 Ok(Json(GetCurrentUsageResponse {
1299 plan: plan.as_str().to_string(),
1300 current_usage: Some(CurrentUsage {
1301 model_requests: UsageCounts {
1302 used: usage.model_requests,
1303 limit: model_requests_limit,
1304 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1305 },
1306 model_request_usage,
1307 edit_predictions: UsageCounts {
1308 used: usage.edit_predictions,
1309 limit: edit_predictions_limit,
1310 remaining: edit_predictions_limit
1311 .map(|limit| (limit - usage.edit_predictions).max(0)),
1312 },
1313 }),
1314 }))
1315}
1316
1317impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1318 fn from(value: SubscriptionStatus) -> Self {
1319 match value {
1320 SubscriptionStatus::Incomplete => Self::Incomplete,
1321 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1322 SubscriptionStatus::Trialing => Self::Trialing,
1323 SubscriptionStatus::Active => Self::Active,
1324 SubscriptionStatus::PastDue => Self::PastDue,
1325 SubscriptionStatus::Canceled => Self::Canceled,
1326 SubscriptionStatus::Unpaid => Self::Unpaid,
1327 SubscriptionStatus::Paused => Self::Paused,
1328 }
1329 }
1330}
1331
1332impl From<CancellationDetailsReason> for StripeCancellationReason {
1333 fn from(value: CancellationDetailsReason) -> Self {
1334 match value {
1335 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1336 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1337 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1338 }
1339 }
1340}
1341
1342/// Finds or creates a billing customer using the provided customer.
1343async fn find_or_create_billing_customer(
1344 app: &Arc<AppState>,
1345 stripe_client: &stripe::Client,
1346 customer_or_id: Expandable<Customer>,
1347) -> anyhow::Result<Option<billing_customer::Model>> {
1348 let customer_id = match &customer_or_id {
1349 Expandable::Id(id) => id,
1350 Expandable::Object(customer) => customer.id.as_ref(),
1351 };
1352
1353 // If we already have a billing customer record associated with the Stripe customer,
1354 // there's nothing more we need to do.
1355 if let Some(billing_customer) = app
1356 .db
1357 .get_billing_customer_by_stripe_customer_id(customer_id)
1358 .await?
1359 {
1360 return Ok(Some(billing_customer));
1361 }
1362
1363 // If all we have is a customer ID, resolve it to a full customer record by
1364 // hitting the Stripe API.
1365 let customer = match customer_or_id {
1366 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1367 Expandable::Object(customer) => *customer,
1368 };
1369
1370 let Some(email) = customer.email else {
1371 return Ok(None);
1372 };
1373
1374 let Some(user) = app.db.get_user_by_email(&email).await? else {
1375 return Ok(None);
1376 };
1377
1378 let billing_customer = app
1379 .db
1380 .create_billing_customer(&CreateBillingCustomerParams {
1381 user_id: user.id,
1382 stripe_customer_id: customer.id.to_string(),
1383 })
1384 .await?;
1385
1386 Ok(Some(billing_customer))
1387}
1388
1389const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1390
1391pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1392 let Some(stripe_billing) = app.stripe_billing.clone() else {
1393 log::warn!("failed to retrieve Stripe billing object");
1394 return;
1395 };
1396 let Some(llm_db) = app.llm_db.clone() else {
1397 log::warn!("failed to retrieve LLM database");
1398 return;
1399 };
1400
1401 let executor = app.executor.clone();
1402 executor.spawn_detached({
1403 let executor = executor.clone();
1404 async move {
1405 loop {
1406 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1407 .await
1408 .context("failed to sync LLM request usage to Stripe")
1409 .trace_err();
1410 executor
1411 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1412 .await;
1413 }
1414 }
1415 });
1416}
1417
1418async fn sync_model_request_usage_with_stripe(
1419 app: &Arc<AppState>,
1420 llm_db: &Arc<LlmDatabase>,
1421 stripe_billing: &Arc<StripeBilling>,
1422) -> anyhow::Result<()> {
1423 let staff_users = app.db.get_staff_users().await?;
1424 let staff_user_ids = staff_users
1425 .iter()
1426 .map(|user| user.id)
1427 .collect::<HashSet<UserId>>();
1428
1429 let usage_meters = llm_db
1430 .get_current_subscription_usage_meters(Utc::now())
1431 .await?;
1432 let usage_meters = usage_meters
1433 .into_iter()
1434 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1435 .collect::<Vec<_>>();
1436 let user_ids = usage_meters
1437 .iter()
1438 .map(|(_, usage)| usage.user_id)
1439 .collect::<HashSet<UserId>>();
1440 let billing_subscriptions = app
1441 .db
1442 .get_active_zed_pro_billing_subscriptions(user_ids)
1443 .await?;
1444
1445 let claude_3_5_sonnet = stripe_billing
1446 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1447 .await?;
1448 let claude_3_7_sonnet = stripe_billing
1449 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1450 .await?;
1451 let claude_3_7_sonnet_max = stripe_billing
1452 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1453 .await?;
1454
1455 for (usage_meter, usage) in usage_meters {
1456 maybe!(async {
1457 let Some((billing_customer, billing_subscription)) =
1458 billing_subscriptions.get(&usage.user_id)
1459 else {
1460 bail!(
1461 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1462 usage.user_id
1463 );
1464 };
1465
1466 let stripe_customer_id = billing_customer
1467 .stripe_customer_id
1468 .parse::<stripe::CustomerId>()
1469 .context("failed to parse Stripe customer ID from database")?;
1470 let stripe_subscription_id = billing_subscription
1471 .stripe_subscription_id
1472 .parse::<stripe::SubscriptionId>()
1473 .context("failed to parse Stripe subscription ID from database")?;
1474
1475 let model = llm_db.model_by_id(usage_meter.model_id)?;
1476
1477 let (price, meter_event_name) = match model.name.as_str() {
1478 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1479 "claude-3-7-sonnet" => match usage_meter.mode {
1480 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1481 CompletionMode::Max => {
1482 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1483 }
1484 },
1485 model_name => {
1486 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1487 }
1488 };
1489
1490 stripe_billing
1491 .subscribe_to_price(&stripe_subscription_id, price)
1492 .await?;
1493 stripe_billing
1494 .bill_model_request_usage(
1495 &stripe_customer_id,
1496 meter_event_name,
1497 usage_meter.requests,
1498 )
1499 .await?;
1500
1501 Ok(())
1502 })
1503 .await
1504 .log_err();
1505 }
1506
1507 Ok(())
1508}