1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
31use crate::rpc::{ResultExt as _, Server};
32use crate::{AppState, Cents, Error, Result};
33use crate::{db::UserId, llm::db::LlmDatabase};
34use crate::{
35 db::{
36 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
37 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
38 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
39 },
40 stripe_billing::StripeBilling,
41};
42
43pub fn router() -> Router {
44 Router::new()
45 .route(
46 "/billing/preferences",
47 get(get_billing_preferences).put(update_billing_preferences),
48 )
49 .route(
50 "/billing/subscriptions",
51 get(list_billing_subscriptions).post(create_billing_subscription),
52 )
53 .route(
54 "/billing/subscriptions/manage",
55 post(manage_billing_subscription),
56 )
57 .route(
58 "/billing/subscriptions/migrate",
59 post(migrate_to_new_billing),
60 )
61 .route("/billing/monthly_spend", get(get_monthly_spend))
62 .route("/billing/usage", get(get_current_usage))
63}
64
65#[derive(Debug, Deserialize)]
66struct GetBillingPreferencesParams {
67 github_user_id: i32,
68}
69
70#[derive(Debug, Serialize)]
71struct BillingPreferencesResponse {
72 max_monthly_llm_usage_spending_in_cents: i32,
73 model_request_overages_enabled: bool,
74 model_request_overages_spend_limit_in_cents: i32,
75}
76
77async fn get_billing_preferences(
78 Extension(app): Extension<Arc<AppState>>,
79 Query(params): Query<GetBillingPreferencesParams>,
80) -> Result<Json<BillingPreferencesResponse>> {
81 let user = app
82 .db
83 .get_user_by_github_user_id(params.github_user_id)
84 .await?
85 .ok_or_else(|| anyhow!("user not found"))?;
86
87 let preferences = app.db.get_billing_preferences(user.id).await?;
88
89 Ok(Json(BillingPreferencesResponse {
90 max_monthly_llm_usage_spending_in_cents: preferences
91 .as_ref()
92 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
93 preferences.max_monthly_llm_usage_spending_in_cents
94 }),
95 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
96 preferences.model_request_overages_enabled
97 }),
98 model_request_overages_spend_limit_in_cents: preferences
99 .as_ref()
100 .map_or(0, |preferences| {
101 preferences.model_request_overages_spend_limit_in_cents
102 }),
103 }))
104}
105
106#[derive(Debug, Deserialize)]
107struct UpdateBillingPreferencesBody {
108 github_user_id: i32,
109 #[serde(default)]
110 max_monthly_llm_usage_spending_in_cents: i32,
111 #[serde(default)]
112 model_request_overages_enabled: bool,
113 #[serde(default)]
114 model_request_overages_spend_limit_in_cents: i32,
115}
116
117async fn update_billing_preferences(
118 Extension(app): Extension<Arc<AppState>>,
119 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
120 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
121) -> Result<Json<BillingPreferencesResponse>> {
122 let user = app
123 .db
124 .get_user_by_github_user_id(body.github_user_id)
125 .await?
126 .ok_or_else(|| anyhow!("user not found"))?;
127
128 let max_monthly_llm_usage_spending_in_cents =
129 body.max_monthly_llm_usage_spending_in_cents.max(0);
130 let model_request_overages_spend_limit_in_cents =
131 body.model_request_overages_spend_limit_in_cents.max(0);
132
133 let billing_preferences =
134 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
135 app.db
136 .update_billing_preferences(
137 user.id,
138 &UpdateBillingPreferencesParams {
139 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
140 max_monthly_llm_usage_spending_in_cents,
141 ),
142 model_request_overages_enabled: ActiveValue::set(
143 body.model_request_overages_enabled,
144 ),
145 model_request_overages_spend_limit_in_cents: ActiveValue::set(
146 model_request_overages_spend_limit_in_cents,
147 ),
148 },
149 )
150 .await?
151 } else {
152 app.db
153 .create_billing_preferences(
154 user.id,
155 &crate::db::CreateBillingPreferencesParams {
156 max_monthly_llm_usage_spending_in_cents,
157 model_request_overages_enabled: body.model_request_overages_enabled,
158 model_request_overages_spend_limit_in_cents,
159 },
160 )
161 .await?
162 };
163
164 SnowflakeRow::new(
165 "Billing Preferences Updated",
166 Some(user.metrics_id),
167 user.admin,
168 None,
169 json!({
170 "user_id": user.id,
171 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
172 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
173 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
174 }),
175 )
176 .write(&app.kinesis_client, &app.config.kinesis_stream)
177 .await
178 .log_err();
179
180 rpc_server.refresh_llm_tokens_for_user(user.id).await;
181
182 Ok(Json(BillingPreferencesResponse {
183 max_monthly_llm_usage_spending_in_cents: billing_preferences
184 .max_monthly_llm_usage_spending_in_cents,
185 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
186 model_request_overages_spend_limit_in_cents: billing_preferences
187 .model_request_overages_spend_limit_in_cents,
188 }))
189}
190
191#[derive(Debug, Deserialize)]
192struct ListBillingSubscriptionsParams {
193 github_user_id: i32,
194}
195
196#[derive(Debug, Serialize)]
197struct BillingSubscriptionJson {
198 id: BillingSubscriptionId,
199 name: String,
200 status: StripeSubscriptionStatus,
201 trial_end_at: Option<String>,
202 cancel_at: Option<String>,
203 /// Whether this subscription can be canceled.
204 is_cancelable: bool,
205}
206
207#[derive(Debug, Serialize)]
208struct ListBillingSubscriptionsResponse {
209 subscriptions: Vec<BillingSubscriptionJson>,
210}
211
212async fn list_billing_subscriptions(
213 Extension(app): Extension<Arc<AppState>>,
214 Query(params): Query<ListBillingSubscriptionsParams>,
215) -> Result<Json<ListBillingSubscriptionsResponse>> {
216 let user = app
217 .db
218 .get_user_by_github_user_id(params.github_user_id)
219 .await?
220 .ok_or_else(|| anyhow!("user not found"))?;
221
222 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
223
224 Ok(Json(ListBillingSubscriptionsResponse {
225 subscriptions: subscriptions
226 .into_iter()
227 .map(|subscription| BillingSubscriptionJson {
228 id: subscription.id,
229 name: match subscription.kind {
230 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
231 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
232 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
233 None => "Zed LLM Usage".to_string(),
234 },
235 status: subscription.stripe_subscription_status,
236 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
237 maybe!({
238 let end_at = subscription.stripe_current_period_end?;
239 let end_at = DateTime::from_timestamp(end_at, 0)?;
240
241 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
242 })
243 } else {
244 None
245 },
246 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
247 cancel_at
248 .and_utc()
249 .to_rfc3339_opts(SecondsFormat::Millis, true)
250 }),
251 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
252 && subscription.stripe_cancel_at.is_none(),
253 })
254 .collect(),
255 }))
256}
257
258#[derive(Debug, Clone, Copy, Deserialize)]
259#[serde(rename_all = "snake_case")]
260enum ProductCode {
261 ZedPro,
262 ZedProTrial,
263 ZedFree,
264}
265
266#[derive(Debug, Deserialize)]
267struct CreateBillingSubscriptionBody {
268 github_user_id: i32,
269 product: Option<ProductCode>,
270}
271
272#[derive(Debug, Serialize)]
273struct CreateBillingSubscriptionResponse {
274 checkout_session_url: String,
275}
276
277/// Initiates a Stripe Checkout session for creating a billing subscription.
278async fn create_billing_subscription(
279 Extension(app): Extension<Arc<AppState>>,
280 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
281) -> Result<Json<CreateBillingSubscriptionResponse>> {
282 let user = app
283 .db
284 .get_user_by_github_user_id(body.github_user_id)
285 .await?
286 .ok_or_else(|| anyhow!("user not found"))?;
287
288 let Some(stripe_client) = app.stripe_client.clone() else {
289 log::error!("failed to retrieve Stripe client");
290 Err(Error::http(
291 StatusCode::NOT_IMPLEMENTED,
292 "not supported".into(),
293 ))?
294 };
295 let Some(stripe_billing) = app.stripe_billing.clone() else {
296 log::error!("failed to retrieve Stripe billing object");
297 Err(Error::http(
298 StatusCode::NOT_IMPLEMENTED,
299 "not supported".into(),
300 ))?
301 };
302 let Some(llm_db) = app.llm_db.clone() else {
303 log::error!("failed to retrieve LLM database");
304 Err(Error::http(
305 StatusCode::NOT_IMPLEMENTED,
306 "not supported".into(),
307 ))?
308 };
309
310 if app.db.has_active_billing_subscription(user.id).await? {
311 return Err(Error::http(
312 StatusCode::CONFLICT,
313 "user already has an active subscription".into(),
314 ));
315 }
316
317 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
318 if let Some(existing_billing_customer) = &existing_billing_customer {
319 if existing_billing_customer.has_overdue_invoices {
320 return Err(Error::http(
321 StatusCode::PAYMENT_REQUIRED,
322 "user has overdue invoices".into(),
323 ));
324 }
325 }
326
327 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
328 CustomerId::from_str(&existing_customer.stripe_customer_id)
329 .context("failed to parse customer ID")?
330 } else {
331 let existing_customer = if let Some(email) = user.email_address.as_deref() {
332 let customers = Customer::list(
333 &stripe_client,
334 &stripe::ListCustomers {
335 email: Some(email),
336 ..Default::default()
337 },
338 )
339 .await?;
340
341 customers.data.first().cloned()
342 } else {
343 None
344 };
345
346 if let Some(existing_customer) = existing_customer {
347 existing_customer.id
348 } else {
349 let customer = Customer::create(
350 &stripe_client,
351 CreateCustomer {
352 email: user.email_address.as_deref(),
353 ..Default::default()
354 },
355 )
356 .await?;
357
358 customer.id
359 }
360 };
361
362 let success_url = format!(
363 "{}/account?checkout_complete=1",
364 app.config.zed_dot_dev_url()
365 );
366
367 let checkout_session_url = match body.product {
368 Some(ProductCode::ZedPro) => {
369 stripe_billing
370 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
371 .await?
372 }
373 Some(ProductCode::ZedProTrial) => {
374 if let Some(existing_billing_customer) = &existing_billing_customer {
375 if existing_billing_customer.trial_started_at.is_some() {
376 return Err(Error::http(
377 StatusCode::FORBIDDEN,
378 "user already used free trial".into(),
379 ));
380 }
381 }
382
383 let feature_flags = app.db.get_user_flags(user.id).await?;
384
385 stripe_billing
386 .checkout_with_zed_pro_trial(
387 customer_id,
388 &user.github_login,
389 feature_flags,
390 &success_url,
391 )
392 .await?
393 }
394 Some(ProductCode::ZedFree) => {
395 stripe_billing
396 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
397 .await?
398 }
399 None => {
400 let default_model = llm_db.model(
401 zed_llm_client::LanguageModelProvider::Anthropic,
402 "claude-3-7-sonnet",
403 )?;
404 let stripe_model = stripe_billing
405 .register_model_for_token_based_usage(default_model)
406 .await?;
407 stripe_billing
408 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
409 .await?
410 }
411 };
412
413 Ok(Json(CreateBillingSubscriptionResponse {
414 checkout_session_url,
415 }))
416}
417
418#[derive(Debug, PartialEq, Deserialize)]
419#[serde(rename_all = "snake_case")]
420enum ManageSubscriptionIntent {
421 /// The user intends to manage their subscription.
422 ///
423 /// This will open the Stripe billing portal without putting the user in a specific flow.
424 ManageSubscription,
425 /// The user intends to upgrade to Zed Pro.
426 UpgradeToPro,
427 /// The user intends to cancel their subscription.
428 Cancel,
429 /// The user intends to stop the cancellation of their subscription.
430 StopCancellation,
431}
432
433#[derive(Debug, Deserialize)]
434struct ManageBillingSubscriptionBody {
435 github_user_id: i32,
436 intent: ManageSubscriptionIntent,
437 /// The ID of the subscription to manage.
438 subscription_id: BillingSubscriptionId,
439}
440
441#[derive(Debug, Serialize)]
442struct ManageBillingSubscriptionResponse {
443 billing_portal_session_url: Option<String>,
444}
445
446/// Initiates a Stripe customer portal session for managing a billing subscription.
447async fn manage_billing_subscription(
448 Extension(app): Extension<Arc<AppState>>,
449 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
450) -> Result<Json<ManageBillingSubscriptionResponse>> {
451 let user = app
452 .db
453 .get_user_by_github_user_id(body.github_user_id)
454 .await?
455 .ok_or_else(|| anyhow!("user not found"))?;
456
457 let Some(stripe_client) = app.stripe_client.clone() else {
458 log::error!("failed to retrieve Stripe client");
459 Err(Error::http(
460 StatusCode::NOT_IMPLEMENTED,
461 "not supported".into(),
462 ))?
463 };
464
465 let Some(stripe_billing) = app.stripe_billing.clone() else {
466 log::error!("failed to retrieve Stripe billing object");
467 Err(Error::http(
468 StatusCode::NOT_IMPLEMENTED,
469 "not supported".into(),
470 ))?
471 };
472
473 let customer = app
474 .db
475 .get_billing_customer_by_user_id(user.id)
476 .await?
477 .ok_or_else(|| anyhow!("billing customer not found"))?;
478 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
479 .context("failed to parse customer ID")?;
480
481 let subscription = app
482 .db
483 .get_billing_subscription_by_id(body.subscription_id)
484 .await?
485 .ok_or_else(|| anyhow!("subscription not found"))?;
486 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
487 .context("failed to parse subscription ID")?;
488
489 if body.intent == ManageSubscriptionIntent::StopCancellation {
490 let updated_stripe_subscription = Subscription::update(
491 &stripe_client,
492 &subscription_id,
493 stripe::UpdateSubscription {
494 cancel_at_period_end: Some(false),
495 ..Default::default()
496 },
497 )
498 .await?;
499
500 app.db
501 .update_billing_subscription(
502 subscription.id,
503 &UpdateBillingSubscriptionParams {
504 stripe_cancel_at: ActiveValue::set(
505 updated_stripe_subscription
506 .cancel_at
507 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
508 .map(|time| time.naive_utc()),
509 ),
510 ..Default::default()
511 },
512 )
513 .await?;
514
515 return Ok(Json(ManageBillingSubscriptionResponse {
516 billing_portal_session_url: None,
517 }));
518 }
519
520 let flow = match body.intent {
521 ManageSubscriptionIntent::ManageSubscription => None,
522 ManageSubscriptionIntent::UpgradeToPro => {
523 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
524 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
525
526 let stripe_subscription =
527 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
528
529 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
530 && stripe_subscription.items.data.iter().any(|item| {
531 item.price
532 .as_ref()
533 .map_or(false, |price| price.id == zed_pro_price_id)
534 });
535 if is_on_zed_pro_trial {
536 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
537 Subscription::update(
538 &stripe_client,
539 &stripe_subscription.id,
540 stripe::UpdateSubscription {
541 trial_end: Some(stripe::Scheduled::now()),
542 ..Default::default()
543 },
544 )
545 .await?;
546
547 return Ok(Json(ManageBillingSubscriptionResponse {
548 billing_portal_session_url: None,
549 }));
550 }
551
552 let subscription_item_to_update = stripe_subscription
553 .items
554 .data
555 .iter()
556 .find_map(|item| {
557 let price = item.price.as_ref()?;
558
559 if price.id == zed_free_price_id {
560 Some(item.id.clone())
561 } else {
562 None
563 }
564 })
565 .ok_or_else(|| anyhow!("No subscription item to update"))?;
566
567 Some(CreateBillingPortalSessionFlowData {
568 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
569 subscription_update_confirm: Some(
570 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
571 subscription: subscription.stripe_subscription_id,
572 items: vec![
573 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
574 id: subscription_item_to_update.to_string(),
575 price: Some(zed_pro_price_id.to_string()),
576 quantity: Some(1),
577 },
578 ],
579 discounts: None,
580 },
581 ),
582 ..Default::default()
583 })
584 }
585 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
586 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
587 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
588 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
589 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
590 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
591 }),
592 ..Default::default()
593 }),
594 subscription_cancel: Some(
595 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
596 subscription: subscription.stripe_subscription_id,
597 retention: None,
598 },
599 ),
600 ..Default::default()
601 }),
602 ManageSubscriptionIntent::StopCancellation => unreachable!(),
603 };
604
605 let mut params = CreateBillingPortalSession::new(customer_id);
606 params.flow_data = flow;
607 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
608 params.return_url = Some(&return_url);
609
610 let session = BillingPortalSession::create(&stripe_client, params).await?;
611
612 Ok(Json(ManageBillingSubscriptionResponse {
613 billing_portal_session_url: Some(session.url),
614 }))
615}
616
617#[derive(Debug, Deserialize)]
618struct MigrateToNewBillingBody {
619 github_user_id: i32,
620}
621
622#[derive(Debug, Serialize)]
623struct MigrateToNewBillingResponse {
624 /// The ID of the subscription that was canceled.
625 canceled_subscription_id: String,
626}
627
628async fn migrate_to_new_billing(
629 Extension(app): Extension<Arc<AppState>>,
630 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
631) -> Result<Json<MigrateToNewBillingResponse>> {
632 let Some(stripe_client) = app.stripe_client.clone() else {
633 log::error!("failed to retrieve Stripe client");
634 Err(Error::http(
635 StatusCode::NOT_IMPLEMENTED,
636 "not supported".into(),
637 ))?
638 };
639
640 let user = app
641 .db
642 .get_user_by_github_user_id(body.github_user_id)
643 .await?
644 .ok_or_else(|| anyhow!("user not found"))?;
645
646 let old_billing_subscriptions_by_user = app
647 .db
648 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
649 .await?;
650
651 let Some((_billing_customer, billing_subscription)) =
652 old_billing_subscriptions_by_user.get(&user.id)
653 else {
654 return Err(Error::http(
655 StatusCode::NOT_FOUND,
656 "No active billing subscriptions to migrate".into(),
657 ));
658 };
659
660 let stripe_subscription_id = billing_subscription
661 .stripe_subscription_id
662 .parse::<stripe::SubscriptionId>()
663 .context("failed to parse Stripe subscription ID from database")?;
664
665 Subscription::cancel(
666 &stripe_client,
667 &stripe_subscription_id,
668 stripe::CancelSubscription {
669 invoice_now: Some(true),
670 ..Default::default()
671 },
672 )
673 .await?;
674
675 let feature_flags = app.db.list_feature_flags().await?;
676
677 for feature_flag in ["new-billing", "assistant2"] {
678 let already_in_feature_flag = feature_flags.iter().any(|flag| flag.flag == feature_flag);
679 if already_in_feature_flag {
680 continue;
681 }
682
683 let feature_flag = feature_flags
684 .iter()
685 .find(|flag| flag.flag == feature_flag)
686 .context("failed to find feature flag: {feature_flag:?}")?;
687
688 app.db.add_user_flag(user.id, feature_flag.id).await?;
689 }
690
691 Ok(Json(MigrateToNewBillingResponse {
692 canceled_subscription_id: stripe_subscription_id.to_string(),
693 }))
694}
695
/// How long to wait between consecutive polls of the Stripe events API.
///
/// The interval has to balance two concerns:
/// 1. It should be short, so that changes made in Stripe show up here quickly.
/// 2. It should be long, so that polling does not eat into our rate limits.
///
/// For comparison, the Sequin folks poll every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// How many events to request per page from the Stripe events API.
///
/// This is the maximum page size, chosen so we make fewer requests to Stripe:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we will see
/// before we stop retrieving further pages.
///
/// This keeps us from over-fetching events from the Stripe events API that we
/// have already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
722
723/// Polls the Stripe events API periodically to reconcile the records in our
724/// database with the data in Stripe.
725pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
726 let Some(stripe_client) = app.stripe_client.clone() else {
727 log::warn!("failed to retrieve Stripe client");
728 return;
729 };
730
731 let executor = app.executor.clone();
732 executor.spawn_detached({
733 let executor = executor.clone();
734 async move {
735 loop {
736 poll_stripe_events(&app, &rpc_server, &stripe_client)
737 .await
738 .log_err();
739
740 executor.sleep(POLL_EVENTS_INTERVAL).await;
741 }
742 }
743 });
744}
745
746async fn poll_stripe_events(
747 app: &Arc<AppState>,
748 rpc_server: &Arc<Server>,
749 stripe_client: &stripe::Client,
750) -> anyhow::Result<()> {
751 fn event_type_to_string(event_type: EventType) -> String {
752 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
753 // so we need to unquote it.
754 event_type.to_string().trim_matches('"').to_string()
755 }
756
757 let event_types = [
758 EventType::CustomerCreated,
759 EventType::CustomerUpdated,
760 EventType::CustomerSubscriptionCreated,
761 EventType::CustomerSubscriptionUpdated,
762 EventType::CustomerSubscriptionPaused,
763 EventType::CustomerSubscriptionResumed,
764 EventType::CustomerSubscriptionDeleted,
765 ]
766 .into_iter()
767 .map(event_type_to_string)
768 .collect::<Vec<_>>();
769
770 let mut pages_of_already_processed_events = 0;
771 let mut unprocessed_events = Vec::new();
772
773 log::info!(
774 "Stripe events: starting retrieval for {}",
775 event_types.join(", ")
776 );
777 let mut params = ListEvents::new();
778 params.types = Some(event_types.clone());
779 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
780
781 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
782 .await?
783 .paginate(params);
784
785 loop {
786 let processed_event_ids = {
787 let event_ids = event_pages
788 .page
789 .data
790 .iter()
791 .map(|event| event.id.as_str())
792 .collect::<Vec<_>>();
793 app.db
794 .get_processed_stripe_events_by_event_ids(&event_ids)
795 .await?
796 .into_iter()
797 .map(|event| event.stripe_event_id)
798 .collect::<Vec<_>>()
799 };
800
801 let mut processed_events_in_page = 0;
802 let events_in_page = event_pages.page.data.len();
803 for event in &event_pages.page.data {
804 if processed_event_ids.contains(&event.id.to_string()) {
805 processed_events_in_page += 1;
806 log::debug!("Stripe events: already processed '{}', skipping", event.id);
807 } else {
808 unprocessed_events.push(event.clone());
809 }
810 }
811
812 if processed_events_in_page == events_in_page {
813 pages_of_already_processed_events += 1;
814 }
815
816 if event_pages.page.has_more {
817 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
818 {
819 log::info!(
820 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
821 );
822 break;
823 } else {
824 log::info!("Stripe events: retrieving next page");
825 event_pages = event_pages.next(&stripe_client).await?;
826 }
827 } else {
828 break;
829 }
830 }
831
832 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
833
834 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
835 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
836
837 for event in unprocessed_events {
838 let event_id = event.id.clone();
839 let processed_event_params = CreateProcessedStripeEventParams {
840 stripe_event_id: event.id.to_string(),
841 stripe_event_type: event_type_to_string(event.type_),
842 stripe_event_created_timestamp: event.created,
843 };
844
845 // If the event has happened too far in the past, we don't want to
846 // process it and risk overwriting other more-recent updates.
847 //
848 // 1 day was chosen arbitrarily. This could be made longer or shorter.
849 let one_day = Duration::from_secs(24 * 60 * 60);
850 let a_day_ago = Utc::now() - one_day;
851 if a_day_ago.timestamp() > event.created {
852 log::info!(
853 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
854 event_id
855 );
856 app.db
857 .create_processed_stripe_event(&processed_event_params)
858 .await?;
859
860 return Ok(());
861 }
862
863 let process_result = match event.type_ {
864 EventType::CustomerCreated | EventType::CustomerUpdated => {
865 handle_customer_event(app, stripe_client, event).await
866 }
867 EventType::CustomerSubscriptionCreated
868 | EventType::CustomerSubscriptionUpdated
869 | EventType::CustomerSubscriptionPaused
870 | EventType::CustomerSubscriptionResumed
871 | EventType::CustomerSubscriptionDeleted => {
872 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
873 }
874 _ => Ok(()),
875 };
876
877 if let Some(()) = process_result
878 .with_context(|| format!("failed to process event {event_id} successfully"))
879 .log_err()
880 {
881 app.db
882 .create_processed_stripe_event(&processed_event_params)
883 .await?;
884 }
885 }
886
887 Ok(())
888}
889
890async fn handle_customer_event(
891 app: &Arc<AppState>,
892 _stripe_client: &stripe::Client,
893 event: stripe::Event,
894) -> anyhow::Result<()> {
895 let EventObject::Customer(customer) = event.data.object else {
896 bail!("unexpected event payload for {}", event.id);
897 };
898
899 log::info!("handling Stripe {} event: {}", event.type_, event.id);
900
901 let Some(email) = customer.email else {
902 log::info!("Stripe customer has no email: skipping");
903 return Ok(());
904 };
905
906 let Some(user) = app.db.get_user_by_email(&email).await? else {
907 log::info!("no user found for email: skipping");
908 return Ok(());
909 };
910
911 if let Some(existing_customer) = app
912 .db
913 .get_billing_customer_by_stripe_customer_id(&customer.id)
914 .await?
915 {
916 app.db
917 .update_billing_customer(
918 existing_customer.id,
919 &UpdateBillingCustomerParams {
920 // For now we just leave the information as-is, as it is not
921 // likely to change.
922 ..Default::default()
923 },
924 )
925 .await?;
926 } else {
927 app.db
928 .create_billing_customer(&CreateBillingCustomerParams {
929 user_id: user.id,
930 stripe_customer_id: customer.id.to_string(),
931 })
932 .await?;
933 }
934
935 Ok(())
936}
937
938async fn handle_customer_subscription_event(
939 app: &Arc<AppState>,
940 rpc_server: &Arc<Server>,
941 stripe_client: &stripe::Client,
942 event: stripe::Event,
943) -> anyhow::Result<()> {
944 let EventObject::Subscription(subscription) = event.data.object else {
945 bail!("unexpected event payload for {}", event.id);
946 };
947
948 log::info!("handling Stripe {} event: {}", event.type_, event.id);
949
950 let subscription_kind = maybe!(async {
951 let stripe_billing = app.stripe_billing.clone()?;
952
953 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
954 let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
955
956 subscription.items.data.iter().find_map(|item| {
957 let price = item.price.as_ref()?;
958
959 if price.id == zed_pro_price_id {
960 Some(if subscription.status == SubscriptionStatus::Trialing {
961 SubscriptionKind::ZedProTrial
962 } else {
963 SubscriptionKind::ZedPro
964 })
965 } else if price.id == zed_free_price_id {
966 Some(SubscriptionKind::ZedFree)
967 } else {
968 None
969 }
970 })
971 })
972 .await;
973
974 let billing_customer =
975 find_or_create_billing_customer(app, stripe_client, subscription.customer)
976 .await?
977 .ok_or_else(|| anyhow!("billing customer not found"))?;
978
979 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
980 if subscription.status == SubscriptionStatus::Trialing {
981 let current_period_start =
982 DateTime::from_timestamp(subscription.current_period_start, 0)
983 .ok_or_else(|| anyhow!("No trial subscription period start"))?;
984
985 app.db
986 .update_billing_customer(
987 billing_customer.id,
988 &UpdateBillingCustomerParams {
989 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
990 ..Default::default()
991 },
992 )
993 .await?;
994 }
995 }
996
997 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
998 && subscription
999 .cancellation_details
1000 .as_ref()
1001 .and_then(|details| details.reason)
1002 .map_or(false, |reason| {
1003 reason == CancellationDetailsReason::PaymentFailed
1004 });
1005
1006 if was_canceled_due_to_payment_failure {
1007 app.db
1008 .update_billing_customer(
1009 billing_customer.id,
1010 &UpdateBillingCustomerParams {
1011 has_overdue_invoices: ActiveValue::set(true),
1012 ..Default::default()
1013 },
1014 )
1015 .await?;
1016 }
1017
1018 if let Some(existing_subscription) = app
1019 .db
1020 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1021 .await?
1022 {
1023 let llm_db = app
1024 .llm_db
1025 .clone()
1026 .ok_or_else(|| anyhow!("LLM DB not initialized"))?;
1027
1028 let new_period_start_at =
1029 chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
1030 .ok_or_else(|| anyhow!("No subscription period start"))?;
1031 let new_period_end_at =
1032 chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
1033 .ok_or_else(|| anyhow!("No subscription period end"))?;
1034
1035 llm_db
1036 .transfer_existing_subscription_usage(
1037 billing_customer.user_id,
1038 &existing_subscription,
1039 subscription_kind,
1040 new_period_start_at,
1041 new_period_end_at,
1042 )
1043 .await?;
1044
1045 app.db
1046 .update_billing_subscription(
1047 existing_subscription.id,
1048 &UpdateBillingSubscriptionParams {
1049 billing_customer_id: ActiveValue::set(billing_customer.id),
1050 kind: ActiveValue::set(subscription_kind),
1051 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1052 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1053 stripe_cancel_at: ActiveValue::set(
1054 subscription
1055 .cancel_at
1056 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1057 .map(|time| time.naive_utc()),
1058 ),
1059 stripe_cancellation_reason: ActiveValue::set(
1060 subscription
1061 .cancellation_details
1062 .and_then(|details| details.reason)
1063 .map(|reason| reason.into()),
1064 ),
1065 stripe_current_period_start: ActiveValue::set(Some(
1066 subscription.current_period_start,
1067 )),
1068 stripe_current_period_end: ActiveValue::set(Some(
1069 subscription.current_period_end,
1070 )),
1071 },
1072 )
1073 .await?;
1074 } else {
1075 // If the user already has an active billing subscription, ignore the
1076 // event and return an `Ok` to signal that it was processed
1077 // successfully.
1078 //
1079 // There is the possibility that this could cause us to not create a
1080 // subscription in the following scenario:
1081 //
1082 // 1. User has an active subscription A
1083 // 2. User cancels subscription A
1084 // 3. User creates a new subscription B
1085 // 4. We process the new subscription B before the cancellation of subscription A
1086 // 5. User ends up with no subscriptions
1087 //
1088 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1089 if app
1090 .db
1091 .has_active_billing_subscription(billing_customer.user_id)
1092 .await?
1093 {
1094 log::info!(
1095 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1096 user_id = billing_customer.user_id,
1097 subscription_id = subscription.id
1098 );
1099 return Ok(());
1100 }
1101
1102 app.db
1103 .create_billing_subscription(&CreateBillingSubscriptionParams {
1104 billing_customer_id: billing_customer.id,
1105 kind: subscription_kind,
1106 stripe_subscription_id: subscription.id.to_string(),
1107 stripe_subscription_status: subscription.status.into(),
1108 stripe_cancellation_reason: subscription
1109 .cancellation_details
1110 .and_then(|details| details.reason)
1111 .map(|reason| reason.into()),
1112 stripe_current_period_start: Some(subscription.current_period_start),
1113 stripe_current_period_end: Some(subscription.current_period_end),
1114 })
1115 .await?;
1116 }
1117
1118 // When the user's subscription changes, we want to refresh their LLM tokens
1119 // to either grant/revoke access.
1120 rpc_server
1121 .refresh_llm_tokens_for_user(billing_customer.user_id)
1122 .await;
1123
1124 Ok(())
1125}
1126
1127#[derive(Debug, Deserialize)]
1128struct GetMonthlySpendParams {
1129 github_user_id: i32,
1130}
1131
1132#[derive(Debug, Serialize)]
1133struct GetMonthlySpendResponse {
1134 monthly_free_tier_spend_in_cents: u32,
1135 monthly_free_tier_allowance_in_cents: u32,
1136 monthly_spend_in_cents: u32,
1137}
1138
1139async fn get_monthly_spend(
1140 Extension(app): Extension<Arc<AppState>>,
1141 Query(params): Query<GetMonthlySpendParams>,
1142) -> Result<Json<GetMonthlySpendResponse>> {
1143 let user = app
1144 .db
1145 .get_user_by_github_user_id(params.github_user_id)
1146 .await?
1147 .ok_or_else(|| anyhow!("user not found"))?;
1148
1149 let Some(llm_db) = app.llm_db.clone() else {
1150 return Err(Error::http(
1151 StatusCode::NOT_IMPLEMENTED,
1152 "LLM database not available".into(),
1153 ));
1154 };
1155
1156 let free_tier = user
1157 .custom_llm_monthly_allowance_in_cents
1158 .map(|allowance| Cents(allowance as u32))
1159 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1160
1161 let spending_for_month = llm_db
1162 .get_user_spending_for_month(user.id, Utc::now())
1163 .await?;
1164
1165 let free_tier_spend = Cents::min(spending_for_month, free_tier);
1166 let monthly_spend = spending_for_month.saturating_sub(free_tier);
1167
1168 Ok(Json(GetMonthlySpendResponse {
1169 monthly_free_tier_spend_in_cents: free_tier_spend.0,
1170 monthly_free_tier_allowance_in_cents: free_tier.0,
1171 monthly_spend_in_cents: monthly_spend.0,
1172 }))
1173}
1174
1175#[derive(Debug, Deserialize)]
1176struct GetCurrentUsageParams {
1177 github_user_id: i32,
1178}
1179
1180#[derive(Debug, Serialize)]
1181struct UsageCounts {
1182 pub used: i32,
1183 pub limit: Option<i32>,
1184 pub remaining: Option<i32>,
1185}
1186
1187#[derive(Debug, Serialize)]
1188struct ModelRequestUsage {
1189 pub model: String,
1190 pub mode: CompletionMode,
1191 pub requests: i32,
1192}
1193
1194#[derive(Debug, Serialize)]
1195struct GetCurrentUsageResponse {
1196 pub model_requests: UsageCounts,
1197 pub model_request_usage: Vec<ModelRequestUsage>,
1198 pub edit_predictions: UsageCounts,
1199}
1200
1201async fn get_current_usage(
1202 Extension(app): Extension<Arc<AppState>>,
1203 Query(params): Query<GetCurrentUsageParams>,
1204) -> Result<Json<GetCurrentUsageResponse>> {
1205 let user = app
1206 .db
1207 .get_user_by_github_user_id(params.github_user_id)
1208 .await?
1209 .ok_or_else(|| anyhow!("user not found"))?;
1210
1211 let Some(llm_db) = app.llm_db.clone() else {
1212 return Err(Error::http(
1213 StatusCode::NOT_IMPLEMENTED,
1214 "LLM database not available".into(),
1215 ));
1216 };
1217
1218 let empty_usage = GetCurrentUsageResponse {
1219 model_requests: UsageCounts {
1220 used: 0,
1221 limit: Some(0),
1222 remaining: Some(0),
1223 },
1224 model_request_usage: Vec::new(),
1225 edit_predictions: UsageCounts {
1226 used: 0,
1227 limit: Some(0),
1228 remaining: Some(0),
1229 },
1230 };
1231
1232 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1233 return Ok(Json(empty_usage));
1234 };
1235
1236 let subscription_period = maybe!({
1237 let period_start_at = subscription.current_period_start_at()?;
1238 let period_end_at = subscription.current_period_end_at()?;
1239
1240 Some((period_start_at, period_end_at))
1241 });
1242
1243 let Some((period_start_at, period_end_at)) = subscription_period else {
1244 return Ok(Json(empty_usage));
1245 };
1246
1247 let usage = llm_db
1248 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1249 .await?;
1250 let Some(usage) = usage else {
1251 return Ok(Json(empty_usage));
1252 };
1253
1254 let plan = match usage.plan {
1255 SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1256 SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1257 SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1258 };
1259
1260 let model_requests_limit = match plan.model_requests_limit() {
1261 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1262 zed_llm_client::UsageLimit::Unlimited => None,
1263 };
1264 let edit_prediction_limit = match plan.edit_predictions_limit() {
1265 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1266 zed_llm_client::UsageLimit::Unlimited => None,
1267 };
1268
1269 let subscription_usage_meters = llm_db
1270 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1271 .await?;
1272
1273 let model_request_usage = subscription_usage_meters
1274 .into_iter()
1275 .filter_map(|(usage_meter, _usage)| {
1276 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1277
1278 Some(ModelRequestUsage {
1279 model: model.name.clone(),
1280 mode: usage_meter.mode,
1281 requests: usage_meter.requests,
1282 })
1283 })
1284 .collect::<Vec<_>>();
1285
1286 Ok(Json(GetCurrentUsageResponse {
1287 model_requests: UsageCounts {
1288 used: usage.model_requests,
1289 limit: model_requests_limit,
1290 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1291 },
1292 model_request_usage,
1293 edit_predictions: UsageCounts {
1294 used: usage.edit_predictions,
1295 limit: edit_prediction_limit,
1296 remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1297 },
1298 }))
1299}
1300
1301impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1302 fn from(value: SubscriptionStatus) -> Self {
1303 match value {
1304 SubscriptionStatus::Incomplete => Self::Incomplete,
1305 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1306 SubscriptionStatus::Trialing => Self::Trialing,
1307 SubscriptionStatus::Active => Self::Active,
1308 SubscriptionStatus::PastDue => Self::PastDue,
1309 SubscriptionStatus::Canceled => Self::Canceled,
1310 SubscriptionStatus::Unpaid => Self::Unpaid,
1311 SubscriptionStatus::Paused => Self::Paused,
1312 }
1313 }
1314}
1315
1316impl From<CancellationDetailsReason> for StripeCancellationReason {
1317 fn from(value: CancellationDetailsReason) -> Self {
1318 match value {
1319 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1320 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1321 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1322 }
1323 }
1324}
1325
1326/// Finds or creates a billing customer using the provided customer.
1327async fn find_or_create_billing_customer(
1328 app: &Arc<AppState>,
1329 stripe_client: &stripe::Client,
1330 customer_or_id: Expandable<Customer>,
1331) -> anyhow::Result<Option<billing_customer::Model>> {
1332 let customer_id = match &customer_or_id {
1333 Expandable::Id(id) => id,
1334 Expandable::Object(customer) => customer.id.as_ref(),
1335 };
1336
1337 // If we already have a billing customer record associated with the Stripe customer,
1338 // there's nothing more we need to do.
1339 if let Some(billing_customer) = app
1340 .db
1341 .get_billing_customer_by_stripe_customer_id(customer_id)
1342 .await?
1343 {
1344 return Ok(Some(billing_customer));
1345 }
1346
1347 // If all we have is a customer ID, resolve it to a full customer record by
1348 // hitting the Stripe API.
1349 let customer = match customer_or_id {
1350 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1351 Expandable::Object(customer) => *customer,
1352 };
1353
1354 let Some(email) = customer.email else {
1355 return Ok(None);
1356 };
1357
1358 let Some(user) = app.db.get_user_by_email(&email).await? else {
1359 return Ok(None);
1360 };
1361
1362 let billing_customer = app
1363 .db
1364 .create_billing_customer(&CreateBillingCustomerParams {
1365 user_id: user.id,
1366 stripe_customer_id: customer.id.to_string(),
1367 })
1368 .await?;
1369
1370 Ok(Some(billing_customer))
1371}
1372
1373const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1374
1375pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
1376 let Some(stripe_billing) = app.stripe_billing.clone() else {
1377 log::warn!("failed to retrieve Stripe billing object");
1378 return;
1379 };
1380 let Some(llm_db) = app.llm_db.clone() else {
1381 log::warn!("failed to retrieve LLM database");
1382 return;
1383 };
1384
1385 let executor = app.executor.clone();
1386 executor.spawn_detached({
1387 let executor = executor.clone();
1388 async move {
1389 loop {
1390 sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
1391 .await
1392 .context("failed to sync LLM usage to Stripe")
1393 .trace_err();
1394 executor
1395 .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
1396 .await;
1397 }
1398 }
1399 });
1400}
1401
1402async fn sync_token_usage_with_stripe(
1403 app: &Arc<AppState>,
1404 llm_db: &Arc<LlmDatabase>,
1405 stripe_billing: &Arc<StripeBilling>,
1406) -> anyhow::Result<()> {
1407 let events = llm_db.get_billing_events().await?;
1408 let user_ids = events
1409 .iter()
1410 .map(|(event, _)| event.user_id)
1411 .collect::<HashSet<UserId>>();
1412 let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1413
1414 for (event, model) in events {
1415 let Some((stripe_db_customer, stripe_db_subscription)) =
1416 stripe_subscriptions.get(&event.user_id)
1417 else {
1418 tracing::warn!(
1419 user_id = event.user_id.0,
1420 "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1421 );
1422 continue;
1423 };
1424 let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1425 .stripe_subscription_id
1426 .parse()
1427 .context("failed to parse stripe subscription id from db")?;
1428 let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1429 .stripe_customer_id
1430 .parse()
1431 .context("failed to parse stripe customer id from db")?;
1432
1433 let stripe_model = stripe_billing
1434 .register_model_for_token_based_usage(&model)
1435 .await?;
1436 stripe_billing
1437 .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1438 .await?;
1439 stripe_billing
1440 .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
1441 .await?;
1442 llm_db.consume_billing_event(event.id).await?;
1443 }
1444
1445 Ok(())
1446}
1447
1448const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1449
1450pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1451 let Some(stripe_billing) = app.stripe_billing.clone() else {
1452 log::warn!("failed to retrieve Stripe billing object");
1453 return;
1454 };
1455 let Some(llm_db) = app.llm_db.clone() else {
1456 log::warn!("failed to retrieve LLM database");
1457 return;
1458 };
1459
1460 let executor = app.executor.clone();
1461 executor.spawn_detached({
1462 let executor = executor.clone();
1463 async move {
1464 loop {
1465 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1466 .await
1467 .context("failed to sync LLM request usage to Stripe")
1468 .trace_err();
1469 executor
1470 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1471 .await;
1472 }
1473 }
1474 });
1475}
1476
1477async fn sync_model_request_usage_with_stripe(
1478 app: &Arc<AppState>,
1479 llm_db: &Arc<LlmDatabase>,
1480 stripe_billing: &Arc<StripeBilling>,
1481) -> anyhow::Result<()> {
1482 let usage_meters = llm_db
1483 .get_current_subscription_usage_meters(Utc::now())
1484 .await?;
1485 let user_ids = usage_meters
1486 .iter()
1487 .map(|(_, usage)| usage.user_id)
1488 .collect::<HashSet<UserId>>();
1489 let billing_subscriptions = app
1490 .db
1491 .get_active_zed_pro_billing_subscriptions(user_ids)
1492 .await?;
1493
1494 let claude_3_5_sonnet = stripe_billing
1495 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1496 .await?;
1497 let claude_3_7_sonnet = stripe_billing
1498 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1499 .await?;
1500 let claude_3_7_sonnet_max = stripe_billing
1501 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1502 .await?;
1503
1504 for (usage_meter, usage) in usage_meters {
1505 maybe!(async {
1506 let Some((billing_customer, billing_subscription)) =
1507 billing_subscriptions.get(&usage.user_id)
1508 else {
1509 bail!(
1510 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1511 usage.user_id
1512 );
1513 };
1514
1515 let stripe_customer_id = billing_customer
1516 .stripe_customer_id
1517 .parse::<stripe::CustomerId>()
1518 .context("failed to parse Stripe customer ID from database")?;
1519 let stripe_subscription_id = billing_subscription
1520 .stripe_subscription_id
1521 .parse::<stripe::SubscriptionId>()
1522 .context("failed to parse Stripe subscription ID from database")?;
1523
1524 let model = llm_db.model_by_id(usage_meter.model_id)?;
1525
1526 let (price, meter_event_name) = match model.name.as_str() {
1527 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1528 "claude-3-7-sonnet" => match usage_meter.mode {
1529 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1530 CompletionMode::Max => {
1531 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1532 }
1533 },
1534 model_name => {
1535 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1536 }
1537 };
1538
1539 stripe_billing
1540 .subscribe_to_price(&stripe_subscription_id, price)
1541 .await?;
1542 stripe_billing
1543 .bill_model_request_usage(
1544 &stripe_customer_id,
1545 meter_event_name,
1546 usage_meter.requests,
1547 )
1548 .await?;
1549
1550 Ok(())
1551 })
1552 .await
1553 .log_err();
1554 }
1555
1556 Ok(())
1557}