1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
22 SubscriptionStatus,
23};
24use util::{ResultExt, maybe};
25
26use crate::api::events::SnowflakeRow;
27use crate::db::billing_subscription::{
28 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
29};
30use crate::llm::db::subscription_usage_meter::CompletionMode;
31use crate::llm::{
32 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
33};
34use crate::rpc::{ResultExt as _, Server};
35use crate::{AppState, Cents, Error, Result};
36use crate::{db::UserId, llm::db::LlmDatabase};
37use crate::{
38 db::{
39 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
40 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
41 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
42 },
43 stripe_billing::StripeBilling,
44};
45
46pub fn router() -> Router {
47 Router::new()
48 .route(
49 "/billing/preferences",
50 get(get_billing_preferences).put(update_billing_preferences),
51 )
52 .route(
53 "/billing/subscriptions",
54 get(list_billing_subscriptions).post(create_billing_subscription),
55 )
56 .route(
57 "/billing/subscriptions/manage",
58 post(manage_billing_subscription),
59 )
60 .route(
61 "/billing/subscriptions/migrate",
62 post(migrate_to_new_billing),
63 )
64 .route(
65 "/billing/subscriptions/sync",
66 post(sync_billing_subscription),
67 )
68 .route("/billing/monthly_spend", get(get_monthly_spend))
69 .route("/billing/usage", get(get_current_usage))
70}
71
/// Query-string parameters accepted by `get_billing_preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID used to look up the user whose preferences are requested.
    github_user_id: i32,
}
76
/// JSON payload returned by both the "get" and "update" billing-preferences handlers.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// RFC 3339 timestamp (millisecond precision, UTC) of when the billing
    /// customer's trial started, or `None` if there is no customer or no trial.
    trial_started_at: Option<String>,
    /// Monthly LLM usage spending cap, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model-request overages are enabled.
    model_request_overages_enabled: bool,
    /// Spend limit for model-request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
84
85async fn get_billing_preferences(
86 Extension(app): Extension<Arc<AppState>>,
87 Query(params): Query<GetBillingPreferencesParams>,
88) -> Result<Json<BillingPreferencesResponse>> {
89 let user = app
90 .db
91 .get_user_by_github_user_id(params.github_user_id)
92 .await?
93 .ok_or_else(|| anyhow!("user not found"))?;
94
95 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
96 let preferences = app.db.get_billing_preferences(user.id).await?;
97
98 Ok(Json(BillingPreferencesResponse {
99 trial_started_at: billing_customer
100 .and_then(|billing_customer| billing_customer.trial_started_at)
101 .map(|trial_started_at| {
102 trial_started_at
103 .and_utc()
104 .to_rfc3339_opts(SecondsFormat::Millis, true)
105 }),
106 max_monthly_llm_usage_spending_in_cents: preferences
107 .as_ref()
108 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
109 preferences.max_monthly_llm_usage_spending_in_cents
110 }),
111 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
112 preferences.model_request_overages_enabled
113 }),
114 model_request_overages_spend_limit_in_cents: preferences
115 .as_ref()
116 .map_or(0, |preferences| {
117 preferences.model_request_overages_spend_limit_in_cents
118 }),
119 }))
120}
121
/// JSON request body accepted by `update_billing_preferences`.
///
/// Every field except `github_user_id` is `#[serde(default)]`, so an omitted
/// field deserializes to `0` / `false`.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID used to look up the user whose preferences are updated.
    github_user_id: i32,
    /// Requested monthly LLM usage spending cap, in cents.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model-request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Requested spend limit for model-request overages, in cents.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
132
133async fn update_billing_preferences(
134 Extension(app): Extension<Arc<AppState>>,
135 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
136 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
137) -> Result<Json<BillingPreferencesResponse>> {
138 let user = app
139 .db
140 .get_user_by_github_user_id(body.github_user_id)
141 .await?
142 .ok_or_else(|| anyhow!("user not found"))?;
143
144 let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
145
146 let max_monthly_llm_usage_spending_in_cents =
147 body.max_monthly_llm_usage_spending_in_cents.max(0);
148 let model_request_overages_spend_limit_in_cents =
149 body.model_request_overages_spend_limit_in_cents.max(0);
150
151 let billing_preferences =
152 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
153 app.db
154 .update_billing_preferences(
155 user.id,
156 &UpdateBillingPreferencesParams {
157 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
158 max_monthly_llm_usage_spending_in_cents,
159 ),
160 model_request_overages_enabled: ActiveValue::set(
161 body.model_request_overages_enabled,
162 ),
163 model_request_overages_spend_limit_in_cents: ActiveValue::set(
164 model_request_overages_spend_limit_in_cents,
165 ),
166 },
167 )
168 .await?
169 } else {
170 app.db
171 .create_billing_preferences(
172 user.id,
173 &crate::db::CreateBillingPreferencesParams {
174 max_monthly_llm_usage_spending_in_cents,
175 model_request_overages_enabled: body.model_request_overages_enabled,
176 model_request_overages_spend_limit_in_cents,
177 },
178 )
179 .await?
180 };
181
182 SnowflakeRow::new(
183 "Billing Preferences Updated",
184 Some(user.metrics_id),
185 user.admin,
186 None,
187 json!({
188 "user_id": user.id,
189 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
190 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
191 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
192 }),
193 )
194 .write(&app.kinesis_client, &app.config.kinesis_stream)
195 .await
196 .log_err();
197
198 rpc_server.refresh_llm_tokens_for_user(user.id).await;
199
200 Ok(Json(BillingPreferencesResponse {
201 trial_started_at: billing_customer
202 .and_then(|billing_customer| billing_customer.trial_started_at)
203 .map(|trial_started_at| {
204 trial_started_at
205 .and_utc()
206 .to_rfc3339_opts(SecondsFormat::Millis, true)
207 }),
208 max_monthly_llm_usage_spending_in_cents: billing_preferences
209 .max_monthly_llm_usage_spending_in_cents,
210 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
211 model_request_overages_spend_limit_in_cents: billing_preferences
212 .model_request_overages_spend_limit_in_cents,
213 }))
214}
215
/// Query-string parameters accepted by `list_billing_subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}
220
/// JSON representation of a single billing subscription.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Database ID of the subscription.
    id: BillingSubscriptionId,
    /// Human-readable plan name derived from the subscription kind.
    name: String,
    /// Subscription status as stored from Stripe.
    status: StripeSubscriptionStatus,
    /// RFC 3339 timestamp at which the trial ends; only populated for
    /// Zed Pro trial subscriptions.
    trial_end_at: Option<String>,
    /// RFC 3339 timestamp at which the subscription is scheduled to cancel, if any.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}
231
/// JSON payload returned by `list_billing_subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The user's billing subscriptions, in the order returned by the database.
    subscriptions: Vec<BillingSubscriptionJson>,
}
236
237async fn list_billing_subscriptions(
238 Extension(app): Extension<Arc<AppState>>,
239 Query(params): Query<ListBillingSubscriptionsParams>,
240) -> Result<Json<ListBillingSubscriptionsResponse>> {
241 let user = app
242 .db
243 .get_user_by_github_user_id(params.github_user_id)
244 .await?
245 .ok_or_else(|| anyhow!("user not found"))?;
246
247 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
248
249 Ok(Json(ListBillingSubscriptionsResponse {
250 subscriptions: subscriptions
251 .into_iter()
252 .map(|subscription| BillingSubscriptionJson {
253 id: subscription.id,
254 name: match subscription.kind {
255 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
256 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
257 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
258 None => "Zed LLM Usage".to_string(),
259 },
260 status: subscription.stripe_subscription_status,
261 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
262 maybe!({
263 let end_at = subscription.stripe_current_period_end?;
264 let end_at = DateTime::from_timestamp(end_at, 0)?;
265
266 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
267 })
268 } else {
269 None
270 },
271 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
272 cancel_at
273 .and_utc()
274 .to_rfc3339_opts(SecondsFormat::Millis, true)
275 }),
276 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
277 && subscription.stripe_cancel_at.is_none(),
278 })
279 .collect(),
280 }))
281}
282
/// The product a new subscription is being created for.
///
/// Deserialized from snake_case strings (`"zed_pro"`, `"zed_pro_trial"`, `"zed_free"`).
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// The paid Zed Pro plan.
    ZedPro,
    /// A trial of the Zed Pro plan.
    ZedProTrial,
    /// The Zed Free plan.
    ZedFree,
}
290
/// JSON request body accepted by `create_billing_subscription`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the subscribing user.
    github_user_id: i32,
    /// The product to start a checkout for.
    product: ProductCode,
}
296
/// JSON payload returned by `create_billing_subscription`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the checkout session produced for the requested product.
    checkout_session_url: String,
}
301
302/// Initiates a Stripe Checkout session for creating a billing subscription.
303async fn create_billing_subscription(
304 Extension(app): Extension<Arc<AppState>>,
305 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
306) -> Result<Json<CreateBillingSubscriptionResponse>> {
307 let user = app
308 .db
309 .get_user_by_github_user_id(body.github_user_id)
310 .await?
311 .ok_or_else(|| anyhow!("user not found"))?;
312
313 let Some(stripe_client) = app.stripe_client.clone() else {
314 log::error!("failed to retrieve Stripe client");
315 Err(Error::http(
316 StatusCode::NOT_IMPLEMENTED,
317 "not supported".into(),
318 ))?
319 };
320 let Some(stripe_billing) = app.stripe_billing.clone() else {
321 log::error!("failed to retrieve Stripe billing object");
322 Err(Error::http(
323 StatusCode::NOT_IMPLEMENTED,
324 "not supported".into(),
325 ))?
326 };
327
328 if app.db.has_active_billing_subscription(user.id).await? {
329 return Err(Error::http(
330 StatusCode::CONFLICT,
331 "user already has an active subscription".into(),
332 ));
333 }
334
335 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
336 if let Some(existing_billing_customer) = &existing_billing_customer {
337 if existing_billing_customer.has_overdue_invoices {
338 return Err(Error::http(
339 StatusCode::PAYMENT_REQUIRED,
340 "user has overdue invoices".into(),
341 ));
342 }
343 }
344
345 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
346 CustomerId::from_str(&existing_customer.stripe_customer_id)
347 .context("failed to parse customer ID")?
348 } else {
349 let existing_customer = if let Some(email) = user.email_address.as_deref() {
350 let customers = Customer::list(
351 &stripe_client,
352 &stripe::ListCustomers {
353 email: Some(email),
354 ..Default::default()
355 },
356 )
357 .await?;
358
359 customers.data.first().cloned()
360 } else {
361 None
362 };
363
364 if let Some(existing_customer) = existing_customer {
365 existing_customer.id
366 } else {
367 let customer = Customer::create(
368 &stripe_client,
369 CreateCustomer {
370 email: user.email_address.as_deref(),
371 ..Default::default()
372 },
373 )
374 .await?;
375
376 customer.id
377 }
378 };
379
380 let success_url = format!(
381 "{}/account?checkout_complete=1",
382 app.config.zed_dot_dev_url()
383 );
384
385 let checkout_session_url = match body.product {
386 ProductCode::ZedPro => {
387 stripe_billing
388 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
389 .await?
390 }
391 ProductCode::ZedProTrial => {
392 if let Some(existing_billing_customer) = &existing_billing_customer {
393 if existing_billing_customer.trial_started_at.is_some() {
394 return Err(Error::http(
395 StatusCode::FORBIDDEN,
396 "user already used free trial".into(),
397 ));
398 }
399 }
400
401 let feature_flags = app.db.get_user_flags(user.id).await?;
402
403 stripe_billing
404 .checkout_with_zed_pro_trial(
405 customer_id,
406 &user.github_login,
407 feature_flags,
408 &success_url,
409 )
410 .await?
411 }
412 ProductCode::ZedFree => {
413 stripe_billing
414 .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
415 .await?
416 }
417 };
418
419 Ok(Json(CreateBillingSubscriptionResponse {
420 checkout_session_url,
421 }))
422}
423
/// What the user wants to do with an existing subscription.
///
/// Deserialized from snake_case strings (e.g. `"update_payment_method"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}
440
/// JSON request body accepted by `manage_billing_subscription`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user managing the subscription.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Path appended to the site base URL to build the post-completion
    /// redirect of the payment-method-update flow; `/account` is used when absent.
    redirect_to: Option<String>,
}
449
/// JSON payload returned by `manage_billing_subscription`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session, or `None` when the request
    /// was fulfilled directly without opening the portal.
    billing_portal_session_url: Option<String>,
}
454
455/// Initiates a Stripe customer portal session for managing a billing subscription.
456async fn manage_billing_subscription(
457 Extension(app): Extension<Arc<AppState>>,
458 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
459) -> Result<Json<ManageBillingSubscriptionResponse>> {
460 let user = app
461 .db
462 .get_user_by_github_user_id(body.github_user_id)
463 .await?
464 .ok_or_else(|| anyhow!("user not found"))?;
465
466 let Some(stripe_client) = app.stripe_client.clone() else {
467 log::error!("failed to retrieve Stripe client");
468 Err(Error::http(
469 StatusCode::NOT_IMPLEMENTED,
470 "not supported".into(),
471 ))?
472 };
473
474 let Some(stripe_billing) = app.stripe_billing.clone() else {
475 log::error!("failed to retrieve Stripe billing object");
476 Err(Error::http(
477 StatusCode::NOT_IMPLEMENTED,
478 "not supported".into(),
479 ))?
480 };
481
482 let customer = app
483 .db
484 .get_billing_customer_by_user_id(user.id)
485 .await?
486 .ok_or_else(|| anyhow!("billing customer not found"))?;
487 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
488 .context("failed to parse customer ID")?;
489
490 let subscription = app
491 .db
492 .get_billing_subscription_by_id(body.subscription_id)
493 .await?
494 .ok_or_else(|| anyhow!("subscription not found"))?;
495 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
496 .context("failed to parse subscription ID")?;
497
498 if body.intent == ManageSubscriptionIntent::StopCancellation {
499 let updated_stripe_subscription = Subscription::update(
500 &stripe_client,
501 &subscription_id,
502 stripe::UpdateSubscription {
503 cancel_at_period_end: Some(false),
504 ..Default::default()
505 },
506 )
507 .await?;
508
509 app.db
510 .update_billing_subscription(
511 subscription.id,
512 &UpdateBillingSubscriptionParams {
513 stripe_cancel_at: ActiveValue::set(
514 updated_stripe_subscription
515 .cancel_at
516 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
517 .map(|time| time.naive_utc()),
518 ),
519 ..Default::default()
520 },
521 )
522 .await?;
523
524 return Ok(Json(ManageBillingSubscriptionResponse {
525 billing_portal_session_url: None,
526 }));
527 }
528
529 let flow = match body.intent {
530 ManageSubscriptionIntent::ManageSubscription => None,
531 ManageSubscriptionIntent::UpgradeToPro => {
532 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
533 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
534
535 let stripe_subscription =
536 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
537
538 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
539 && stripe_subscription.items.data.iter().any(|item| {
540 item.price
541 .as_ref()
542 .map_or(false, |price| price.id == zed_pro_price_id)
543 });
544 if is_on_zed_pro_trial {
545 let payment_methods = PaymentMethod::list(
546 &stripe_client,
547 &stripe::ListPaymentMethods {
548 customer: Some(stripe_subscription.customer.id()),
549 ..Default::default()
550 },
551 )
552 .await?;
553
554 let has_payment_method = !payment_methods.data.is_empty();
555 if !has_payment_method {
556 return Err(Error::http(
557 StatusCode::BAD_REQUEST,
558 "missing payment method".into(),
559 ));
560 }
561
562 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
563 Subscription::update(
564 &stripe_client,
565 &stripe_subscription.id,
566 stripe::UpdateSubscription {
567 trial_end: Some(stripe::Scheduled::now()),
568 ..Default::default()
569 },
570 )
571 .await?;
572
573 return Ok(Json(ManageBillingSubscriptionResponse {
574 billing_portal_session_url: None,
575 }));
576 }
577
578 let subscription_item_to_update = stripe_subscription
579 .items
580 .data
581 .iter()
582 .find_map(|item| {
583 let price = item.price.as_ref()?;
584
585 if price.id == zed_free_price_id {
586 Some(item.id.clone())
587 } else {
588 None
589 }
590 })
591 .ok_or_else(|| anyhow!("No subscription item to update"))?;
592
593 Some(CreateBillingPortalSessionFlowData {
594 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
595 subscription_update_confirm: Some(
596 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
597 subscription: subscription.stripe_subscription_id,
598 items: vec![
599 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
600 id: subscription_item_to_update.to_string(),
601 price: Some(zed_pro_price_id.to_string()),
602 quantity: Some(1),
603 },
604 ],
605 discounts: None,
606 },
607 ),
608 ..Default::default()
609 })
610 }
611 ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
612 type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
613 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
614 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
615 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
616 return_url: format!(
617 "{}{path}",
618 app.config.zed_dot_dev_url(),
619 path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
620 ),
621 }),
622 ..Default::default()
623 }),
624 ..Default::default()
625 }),
626 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
627 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
628 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
629 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
630 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
631 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
632 }),
633 ..Default::default()
634 }),
635 subscription_cancel: Some(
636 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
637 subscription: subscription.stripe_subscription_id,
638 retention: None,
639 },
640 ),
641 ..Default::default()
642 }),
643 ManageSubscriptionIntent::StopCancellation => unreachable!(),
644 };
645
646 let mut params = CreateBillingPortalSession::new(customer_id);
647 params.flow_data = flow;
648 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
649 params.return_url = Some(&return_url);
650
651 let session = BillingPortalSession::create(&stripe_client, params).await?;
652
653 Ok(Json(ManageBillingSubscriptionResponse {
654 billing_portal_session_url: Some(session.url),
655 }))
656}
657
/// JSON request body accepted by `migrate_to_new_billing`.
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
    /// GitHub user ID used to look up the user being migrated.
    github_user_id: i32,
}
662
/// JSON payload returned by `migrate_to_new_billing`.
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
    /// The ID of the subscription that was canceled.
    ///
    /// `None` when the user had no active subscription to cancel.
    canceled_subscription_id: Option<String>,
}
668
669async fn migrate_to_new_billing(
670 Extension(app): Extension<Arc<AppState>>,
671 extract::Json(body): extract::Json<MigrateToNewBillingBody>,
672) -> Result<Json<MigrateToNewBillingResponse>> {
673 let Some(stripe_client) = app.stripe_client.clone() else {
674 log::error!("failed to retrieve Stripe client");
675 Err(Error::http(
676 StatusCode::NOT_IMPLEMENTED,
677 "not supported".into(),
678 ))?
679 };
680
681 let user = app
682 .db
683 .get_user_by_github_user_id(body.github_user_id)
684 .await?
685 .ok_or_else(|| anyhow!("user not found"))?;
686
687 let old_billing_subscriptions_by_user = app
688 .db
689 .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
690 .await?;
691
692 let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
693 old_billing_subscriptions_by_user.get(&user.id)
694 {
695 let stripe_subscription_id = billing_subscription
696 .stripe_subscription_id
697 .parse::<stripe::SubscriptionId>()
698 .context("failed to parse Stripe subscription ID from database")?;
699
700 Subscription::cancel(
701 &stripe_client,
702 &stripe_subscription_id,
703 stripe::CancelSubscription {
704 invoice_now: Some(true),
705 ..Default::default()
706 },
707 )
708 .await?;
709
710 Some(stripe_subscription_id)
711 } else {
712 None
713 };
714
715 let all_feature_flags = app.db.list_feature_flags().await?;
716 let user_feature_flags = app.db.get_user_flags(user.id).await?;
717
718 for feature_flag in ["new-billing", "assistant2"] {
719 let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
720 if already_in_feature_flag {
721 continue;
722 }
723
724 let feature_flag = all_feature_flags
725 .iter()
726 .find(|flag| flag.flag == feature_flag)
727 .context("failed to find feature flag: {feature_flag:?}")?;
728
729 app.db.add_user_flag(user.id, feature_flag.id).await?;
730 }
731
732 Ok(Json(MigrateToNewBillingResponse {
733 canceled_subscription_id: canceled_subscription_id
734 .map(|subscription_id| subscription_id.to_string()),
735 }))
736}
737
/// JSON request body accepted by `sync_billing_subscription`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID used to look up the user whose subscriptions are synced.
    github_user_id: i32,
}
742
/// JSON payload returned by `sync_billing_subscription`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// Stripe customer ID stored for the user's billing customer.
    stripe_customer_id: String,
}
747
748async fn sync_billing_subscription(
749 Extension(app): Extension<Arc<AppState>>,
750 extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
751) -> Result<Json<SyncBillingSubscriptionResponse>> {
752 let Some(stripe_client) = app.stripe_client.clone() else {
753 log::error!("failed to retrieve Stripe client");
754 Err(Error::http(
755 StatusCode::NOT_IMPLEMENTED,
756 "not supported".into(),
757 ))?
758 };
759
760 let user = app
761 .db
762 .get_user_by_github_user_id(body.github_user_id)
763 .await?
764 .ok_or_else(|| anyhow!("user not found"))?;
765
766 let billing_customer = app
767 .db
768 .get_billing_customer_by_user_id(user.id)
769 .await?
770 .ok_or_else(|| anyhow!("billing customer not found"))?;
771 let stripe_customer_id = billing_customer
772 .stripe_customer_id
773 .parse::<stripe::CustomerId>()
774 .context("failed to parse Stripe customer ID from database")?;
775
776 let subscriptions = Subscription::list(
777 &stripe_client,
778 &stripe::ListSubscriptions {
779 customer: Some(stripe_customer_id),
780 // Sync all non-canceled subscriptions.
781 status: None,
782 ..Default::default()
783 },
784 )
785 .await?;
786
787 for subscription in subscriptions.data {
788 let subscription_id = subscription.id.clone();
789
790 sync_subscription(&app, &stripe_client, subscription)
791 .await
792 .with_context(|| {
793 format!(
794 "failed to sync subscription {subscription_id} for user {}",
795 user.id,
796 )
797 })?;
798 }
799
800 Ok(Json(SyncBillingSubscriptionResponse {
801 stripe_customer_id: billing_customer.stripe_customer_id.clone(),
802 }))
803}
804
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// We currently poll every 5 seconds. As a point of reference, the Sequin
/// folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
817
/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
/// The quote below describes the allowed range of the `limit` parameter:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
824
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
///
/// The pages are counted across a whole poll; they do not have to be consecutive.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
831
832/// Polls the Stripe events API periodically to reconcile the records in our
833/// database with the data in Stripe.
834pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
835 let Some(stripe_client) = app.stripe_client.clone() else {
836 log::warn!("failed to retrieve Stripe client");
837 return;
838 };
839
840 let executor = app.executor.clone();
841 executor.spawn_detached({
842 let executor = executor.clone();
843 async move {
844 loop {
845 poll_stripe_events(&app, &rpc_server, &stripe_client)
846 .await
847 .log_err();
848
849 executor.sleep(POLL_EVENTS_INTERVAL).await;
850 }
851 }
852 });
853}
854
855async fn poll_stripe_events(
856 app: &Arc<AppState>,
857 rpc_server: &Arc<Server>,
858 stripe_client: &stripe::Client,
859) -> anyhow::Result<()> {
860 fn event_type_to_string(event_type: EventType) -> String {
861 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
862 // so we need to unquote it.
863 event_type.to_string().trim_matches('"').to_string()
864 }
865
866 let event_types = [
867 EventType::CustomerCreated,
868 EventType::CustomerUpdated,
869 EventType::CustomerSubscriptionCreated,
870 EventType::CustomerSubscriptionUpdated,
871 EventType::CustomerSubscriptionPaused,
872 EventType::CustomerSubscriptionResumed,
873 EventType::CustomerSubscriptionDeleted,
874 ]
875 .into_iter()
876 .map(event_type_to_string)
877 .collect::<Vec<_>>();
878
879 let mut pages_of_already_processed_events = 0;
880 let mut unprocessed_events = Vec::new();
881
882 log::info!(
883 "Stripe events: starting retrieval for {}",
884 event_types.join(", ")
885 );
886 let mut params = ListEvents::new();
887 params.types = Some(event_types.clone());
888 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
889
890 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
891 .await?
892 .paginate(params);
893
894 loop {
895 let processed_event_ids = {
896 let event_ids = event_pages
897 .page
898 .data
899 .iter()
900 .map(|event| event.id.as_str())
901 .collect::<Vec<_>>();
902 app.db
903 .get_processed_stripe_events_by_event_ids(&event_ids)
904 .await?
905 .into_iter()
906 .map(|event| event.stripe_event_id)
907 .collect::<Vec<_>>()
908 };
909
910 let mut processed_events_in_page = 0;
911 let events_in_page = event_pages.page.data.len();
912 for event in &event_pages.page.data {
913 if processed_event_ids.contains(&event.id.to_string()) {
914 processed_events_in_page += 1;
915 log::debug!("Stripe events: already processed '{}', skipping", event.id);
916 } else {
917 unprocessed_events.push(event.clone());
918 }
919 }
920
921 if processed_events_in_page == events_in_page {
922 pages_of_already_processed_events += 1;
923 }
924
925 if event_pages.page.has_more {
926 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
927 {
928 log::info!(
929 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
930 );
931 break;
932 } else {
933 log::info!("Stripe events: retrieving next page");
934 event_pages = event_pages.next(&stripe_client).await?;
935 }
936 } else {
937 break;
938 }
939 }
940
941 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
942
943 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
944 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
945
946 for event in unprocessed_events {
947 let event_id = event.id.clone();
948 let processed_event_params = CreateProcessedStripeEventParams {
949 stripe_event_id: event.id.to_string(),
950 stripe_event_type: event_type_to_string(event.type_),
951 stripe_event_created_timestamp: event.created,
952 };
953
954 // If the event has happened too far in the past, we don't want to
955 // process it and risk overwriting other more-recent updates.
956 //
957 // 1 day was chosen arbitrarily. This could be made longer or shorter.
958 let one_day = Duration::from_secs(24 * 60 * 60);
959 let a_day_ago = Utc::now() - one_day;
960 if a_day_ago.timestamp() > event.created {
961 log::info!(
962 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
963 event_id
964 );
965 app.db
966 .create_processed_stripe_event(&processed_event_params)
967 .await?;
968
969 return Ok(());
970 }
971
972 let process_result = match event.type_ {
973 EventType::CustomerCreated | EventType::CustomerUpdated => {
974 handle_customer_event(app, stripe_client, event).await
975 }
976 EventType::CustomerSubscriptionCreated
977 | EventType::CustomerSubscriptionUpdated
978 | EventType::CustomerSubscriptionPaused
979 | EventType::CustomerSubscriptionResumed
980 | EventType::CustomerSubscriptionDeleted => {
981 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
982 }
983 _ => Ok(()),
984 };
985
986 if let Some(()) = process_result
987 .with_context(|| format!("failed to process event {event_id} successfully"))
988 .log_err()
989 {
990 app.db
991 .create_processed_stripe_event(&processed_event_params)
992 .await?;
993 }
994 }
995
996 Ok(())
997}
998
999async fn handle_customer_event(
1000 app: &Arc<AppState>,
1001 _stripe_client: &stripe::Client,
1002 event: stripe::Event,
1003) -> anyhow::Result<()> {
1004 let EventObject::Customer(customer) = event.data.object else {
1005 bail!("unexpected event payload for {}", event.id);
1006 };
1007
1008 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1009
1010 let Some(email) = customer.email else {
1011 log::info!("Stripe customer has no email: skipping");
1012 return Ok(());
1013 };
1014
1015 let Some(user) = app.db.get_user_by_email(&email).await? else {
1016 log::info!("no user found for email: skipping");
1017 return Ok(());
1018 };
1019
1020 if let Some(existing_customer) = app
1021 .db
1022 .get_billing_customer_by_stripe_customer_id(&customer.id)
1023 .await?
1024 {
1025 app.db
1026 .update_billing_customer(
1027 existing_customer.id,
1028 &UpdateBillingCustomerParams {
1029 // For now we just leave the information as-is, as it is not
1030 // likely to change.
1031 ..Default::default()
1032 },
1033 )
1034 .await?;
1035 } else {
1036 app.db
1037 .create_billing_customer(&CreateBillingCustomerParams {
1038 user_id: user.id,
1039 stripe_customer_id: customer.id.to_string(),
1040 })
1041 .await?;
1042 }
1043
1044 Ok(())
1045}
1046
1047async fn sync_subscription(
1048 app: &Arc<AppState>,
1049 stripe_client: &stripe::Client,
1050 subscription: stripe::Subscription,
1051) -> anyhow::Result<billing_customer::Model> {
1052 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1053 stripe_billing
1054 .determine_subscription_kind(&subscription)
1055 .await
1056 } else {
1057 None
1058 };
1059
1060 let billing_customer =
1061 find_or_create_billing_customer(app, stripe_client, subscription.customer)
1062 .await?
1063 .ok_or_else(|| anyhow!("billing customer not found"))?;
1064
1065 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1066 if subscription.status == SubscriptionStatus::Trialing {
1067 let current_period_start =
1068 DateTime::from_timestamp(subscription.current_period_start, 0)
1069 .ok_or_else(|| anyhow!("No trial subscription period start"))?;
1070
1071 app.db
1072 .update_billing_customer(
1073 billing_customer.id,
1074 &UpdateBillingCustomerParams {
1075 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1076 ..Default::default()
1077 },
1078 )
1079 .await?;
1080 }
1081 }
1082
1083 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1084 && subscription
1085 .cancellation_details
1086 .as_ref()
1087 .and_then(|details| details.reason)
1088 .map_or(false, |reason| {
1089 reason == CancellationDetailsReason::PaymentFailed
1090 });
1091
1092 if was_canceled_due_to_payment_failure {
1093 app.db
1094 .update_billing_customer(
1095 billing_customer.id,
1096 &UpdateBillingCustomerParams {
1097 has_overdue_invoices: ActiveValue::set(true),
1098 ..Default::default()
1099 },
1100 )
1101 .await?;
1102 }
1103
1104 if let Some(existing_subscription) = app
1105 .db
1106 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1107 .await?
1108 {
1109 app.db
1110 .update_billing_subscription(
1111 existing_subscription.id,
1112 &UpdateBillingSubscriptionParams {
1113 billing_customer_id: ActiveValue::set(billing_customer.id),
1114 kind: ActiveValue::set(subscription_kind),
1115 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1116 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1117 stripe_cancel_at: ActiveValue::set(
1118 subscription
1119 .cancel_at
1120 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1121 .map(|time| time.naive_utc()),
1122 ),
1123 stripe_cancellation_reason: ActiveValue::set(
1124 subscription
1125 .cancellation_details
1126 .and_then(|details| details.reason)
1127 .map(|reason| reason.into()),
1128 ),
1129 stripe_current_period_start: ActiveValue::set(Some(
1130 subscription.current_period_start,
1131 )),
1132 stripe_current_period_end: ActiveValue::set(Some(
1133 subscription.current_period_end,
1134 )),
1135 },
1136 )
1137 .await?;
1138 } else {
1139 // If the user already has an active billing subscription, ignore the
1140 // event and return an `Ok` to signal that it was processed
1141 // successfully.
1142 //
1143 // There is the possibility that this could cause us to not create a
1144 // subscription in the following scenario:
1145 //
1146 // 1. User has an active subscription A
1147 // 2. User cancels subscription A
1148 // 3. User creates a new subscription B
1149 // 4. We process the new subscription B before the cancellation of subscription A
1150 // 5. User ends up with no subscriptions
1151 //
1152 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1153 if app
1154 .db
1155 .has_active_billing_subscription(billing_customer.user_id)
1156 .await?
1157 {
1158 log::info!(
1159 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1160 user_id = billing_customer.user_id,
1161 subscription_id = subscription.id
1162 );
1163 return Ok(billing_customer);
1164 }
1165
1166 app.db
1167 .create_billing_subscription(&CreateBillingSubscriptionParams {
1168 billing_customer_id: billing_customer.id,
1169 kind: subscription_kind,
1170 stripe_subscription_id: subscription.id.to_string(),
1171 stripe_subscription_status: subscription.status.into(),
1172 stripe_cancellation_reason: subscription
1173 .cancellation_details
1174 .and_then(|details| details.reason)
1175 .map(|reason| reason.into()),
1176 stripe_current_period_start: Some(subscription.current_period_start),
1177 stripe_current_period_end: Some(subscription.current_period_end),
1178 })
1179 .await?;
1180 }
1181
1182 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1183 if subscription.status == SubscriptionStatus::Canceled
1184 || subscription.status == SubscriptionStatus::Paused
1185 {
1186 let stripe_customer_id = billing_customer
1187 .stripe_customer_id
1188 .parse::<stripe::CustomerId>()
1189 .context("failed to parse Stripe customer ID from database")?;
1190
1191 stripe_billing
1192 .subscribe_to_zed_free(stripe_customer_id)
1193 .await?;
1194 }
1195 }
1196
1197 Ok(billing_customer)
1198}
1199
1200async fn handle_customer_subscription_event(
1201 app: &Arc<AppState>,
1202 rpc_server: &Arc<Server>,
1203 stripe_client: &stripe::Client,
1204 event: stripe::Event,
1205) -> anyhow::Result<()> {
1206 let EventObject::Subscription(subscription) = event.data.object else {
1207 bail!("unexpected event payload for {}", event.id);
1208 };
1209
1210 log::info!("handling Stripe {} event: {}", event.type_, event.id);
1211
1212 let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1213
1214 // When the user's subscription changes, push down any changes to their plan.
1215 rpc_server
1216 .update_plan_for_user(billing_customer.user_id)
1217 .await
1218 .trace_err();
1219
1220 // When the user's subscription changes, we want to refresh their LLM tokens
1221 // to either grant/revoke access.
1222 rpc_server
1223 .refresh_llm_tokens_for_user(billing_customer.user_id)
1224 .await;
1225
1226 Ok(())
1227}
1228
/// Query-string parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID used to look up the user whose spend is reported.
    github_user_id: i32,
}
1233
/// JSON body returned by `get_monthly_spend`. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// Portion of this month's spending that falls within the free-tier
    /// allowance (the smaller of the month's spending and the allowance).
    monthly_free_tier_spend_in_cents: u32,
    /// Free-tier allowance applied to the user: their custom allowance when
    /// one is set, otherwise `FREE_TIER_MONTHLY_SPENDING_LIMIT`.
    monthly_free_tier_allowance_in_cents: u32,
    /// This month's spending beyond the free-tier allowance (zero when the
    /// allowance has not been exceeded).
    monthly_spend_in_cents: u32,
}
1240
1241async fn get_monthly_spend(
1242 Extension(app): Extension<Arc<AppState>>,
1243 Query(params): Query<GetMonthlySpendParams>,
1244) -> Result<Json<GetMonthlySpendResponse>> {
1245 let user = app
1246 .db
1247 .get_user_by_github_user_id(params.github_user_id)
1248 .await?
1249 .ok_or_else(|| anyhow!("user not found"))?;
1250
1251 let Some(llm_db) = app.llm_db.clone() else {
1252 return Err(Error::http(
1253 StatusCode::NOT_IMPLEMENTED,
1254 "LLM database not available".into(),
1255 ));
1256 };
1257
1258 let free_tier = user
1259 .custom_llm_monthly_allowance_in_cents
1260 .map(|allowance| Cents(allowance as u32))
1261 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1262
1263 let spending_for_month = llm_db
1264 .get_user_spending_for_month(user.id, Utc::now())
1265 .await?;
1266
1267 let free_tier_spend = Cents::min(spending_for_month, free_tier);
1268 let monthly_spend = spending_for_month.saturating_sub(free_tier);
1269
1270 Ok(Json(GetMonthlySpendResponse {
1271 monthly_free_tier_spend_in_cents: free_tier_spend.0,
1272 monthly_free_tier_allowance_in_cents: free_tier.0,
1273 monthly_spend_in_cents: monthly_spend.0,
1274 }))
1275}
1276
/// Query-string parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID used to look up the user whose usage is reported.
    github_user_id: i32,
}
1281
/// Usage of one metered resource, as reported by `get_current_usage`.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Amount used so far.
    pub used: i32,
    /// The plan's limit for this resource; `None` when the plan is unlimited.
    pub limit: Option<i32>,
    /// What is left of `limit` (never below zero); `None` when the plan is
    /// unlimited.
    pub remaining: Option<i32>,
}
1288
/// Request count taken from a single subscription usage meter, labelled with
/// the model and completion mode that meter tracks.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the meter belongs to.
    pub model: String,
    /// Completion mode recorded on the meter.
    pub mode: CompletionMode,
    /// Number of requests recorded on the meter.
    pub requests: i32,
}
1295
/// Usage figures included in the `get_current_usage` response.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Aggregate model-request usage measured against the plan's limit.
    pub model_requests: UsageCounts,
    /// Per-meter breakdown of model requests; empty when no usage has been
    /// recorded for the subscription period.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit-prediction usage measured against the plan's limit.
    pub edit_predictions: UsageCounts,
}
1302
/// JSON body returned by `get_current_usage`.
///
/// The `Default` value (empty `plan`, no `current_usage`) is what the handler
/// returns when the user has no active billing subscription or the
/// subscription has no current period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// String form of the plan the usage is measured against.
    pub plan: String,
    /// Usage for the current subscription period, when one could be determined.
    pub current_usage: Option<CurrentUsage>,
}
1308
1309async fn get_current_usage(
1310 Extension(app): Extension<Arc<AppState>>,
1311 Query(params): Query<GetCurrentUsageParams>,
1312) -> Result<Json<GetCurrentUsageResponse>> {
1313 let user = app
1314 .db
1315 .get_user_by_github_user_id(params.github_user_id)
1316 .await?
1317 .ok_or_else(|| anyhow!("user not found"))?;
1318
1319 let feature_flags = app.db.get_user_flags(user.id).await?;
1320 let has_extended_trial = feature_flags
1321 .iter()
1322 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1323
1324 let Some(llm_db) = app.llm_db.clone() else {
1325 return Err(Error::http(
1326 StatusCode::NOT_IMPLEMENTED,
1327 "LLM database not available".into(),
1328 ));
1329 };
1330
1331 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1332 return Ok(Json(GetCurrentUsageResponse::default()));
1333 };
1334
1335 let subscription_period = maybe!({
1336 let period_start_at = subscription.current_period_start_at()?;
1337 let period_end_at = subscription.current_period_end_at()?;
1338
1339 Some((period_start_at, period_end_at))
1340 });
1341
1342 let Some((period_start_at, period_end_at)) = subscription_period else {
1343 return Ok(Json(GetCurrentUsageResponse::default()));
1344 };
1345
1346 let usage = llm_db
1347 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1348 .await?;
1349
1350 let plan = usage
1351 .as_ref()
1352 .map(|usage| usage.plan.into())
1353 .unwrap_or_else(|| {
1354 subscription
1355 .kind
1356 .map(Into::into)
1357 .unwrap_or(zed_llm_client::Plan::ZedFree)
1358 });
1359
1360 let model_requests_limit = match plan.model_requests_limit() {
1361 zed_llm_client::UsageLimit::Limited(limit) => {
1362 let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1363 1_000
1364 } else {
1365 limit
1366 };
1367
1368 Some(limit)
1369 }
1370 zed_llm_client::UsageLimit::Unlimited => None,
1371 };
1372
1373 let edit_predictions_limit = match plan.edit_predictions_limit() {
1374 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1375 zed_llm_client::UsageLimit::Unlimited => None,
1376 };
1377
1378 let Some(usage) = usage else {
1379 return Ok(Json(GetCurrentUsageResponse {
1380 plan: plan.as_str().to_string(),
1381 current_usage: Some(CurrentUsage {
1382 model_requests: UsageCounts {
1383 used: 0,
1384 limit: model_requests_limit,
1385 remaining: model_requests_limit,
1386 },
1387 model_request_usage: Vec::new(),
1388 edit_predictions: UsageCounts {
1389 used: 0,
1390 limit: edit_predictions_limit,
1391 remaining: edit_predictions_limit,
1392 },
1393 }),
1394 }));
1395 };
1396
1397 let subscription_usage_meters = llm_db
1398 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1399 .await?;
1400
1401 let model_request_usage = subscription_usage_meters
1402 .into_iter()
1403 .filter_map(|(usage_meter, _usage)| {
1404 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1405
1406 Some(ModelRequestUsage {
1407 model: model.name.clone(),
1408 mode: usage_meter.mode,
1409 requests: usage_meter.requests,
1410 })
1411 })
1412 .collect::<Vec<_>>();
1413
1414 Ok(Json(GetCurrentUsageResponse {
1415 plan: plan.as_str().to_string(),
1416 current_usage: Some(CurrentUsage {
1417 model_requests: UsageCounts {
1418 used: usage.model_requests,
1419 limit: model_requests_limit,
1420 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1421 },
1422 model_request_usage,
1423 edit_predictions: UsageCounts {
1424 used: usage.edit_predictions,
1425 limit: edit_predictions_limit,
1426 remaining: edit_predictions_limit
1427 .map(|limit| (limit - usage.edit_predictions).max(0)),
1428 },
1429 }),
1430 }))
1431}
1432
1433impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1434 fn from(value: SubscriptionStatus) -> Self {
1435 match value {
1436 SubscriptionStatus::Incomplete => Self::Incomplete,
1437 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1438 SubscriptionStatus::Trialing => Self::Trialing,
1439 SubscriptionStatus::Active => Self::Active,
1440 SubscriptionStatus::PastDue => Self::PastDue,
1441 SubscriptionStatus::Canceled => Self::Canceled,
1442 SubscriptionStatus::Unpaid => Self::Unpaid,
1443 SubscriptionStatus::Paused => Self::Paused,
1444 }
1445 }
1446}
1447
1448impl From<CancellationDetailsReason> for StripeCancellationReason {
1449 fn from(value: CancellationDetailsReason) -> Self {
1450 match value {
1451 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1452 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1453 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1454 }
1455 }
1456}
1457
1458/// Finds or creates a billing customer using the provided customer.
1459async fn find_or_create_billing_customer(
1460 app: &Arc<AppState>,
1461 stripe_client: &stripe::Client,
1462 customer_or_id: Expandable<Customer>,
1463) -> anyhow::Result<Option<billing_customer::Model>> {
1464 let customer_id = match &customer_or_id {
1465 Expandable::Id(id) => id,
1466 Expandable::Object(customer) => customer.id.as_ref(),
1467 };
1468
1469 // If we already have a billing customer record associated with the Stripe customer,
1470 // there's nothing more we need to do.
1471 if let Some(billing_customer) = app
1472 .db
1473 .get_billing_customer_by_stripe_customer_id(customer_id)
1474 .await?
1475 {
1476 return Ok(Some(billing_customer));
1477 }
1478
1479 // If all we have is a customer ID, resolve it to a full customer record by
1480 // hitting the Stripe API.
1481 let customer = match customer_or_id {
1482 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1483 Expandable::Object(customer) => *customer,
1484 };
1485
1486 let Some(email) = customer.email else {
1487 return Ok(None);
1488 };
1489
1490 let Some(user) = app.db.get_user_by_email(&email).await? else {
1491 return Ok(None);
1492 };
1493
1494 let billing_customer = app
1495 .db
1496 .create_billing_customer(&CreateBillingCustomerParams {
1497 user_id: user.id,
1498 stripe_customer_id: customer.id.to_string(),
1499 })
1500 .await?;
1501
1502 Ok(Some(billing_customer))
1503}
1504
/// How long the background sync loop sleeps after each attempt to sync LLM
/// request usage to Stripe.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1506
1507pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1508 let Some(stripe_billing) = app.stripe_billing.clone() else {
1509 log::warn!("failed to retrieve Stripe billing object");
1510 return;
1511 };
1512 let Some(llm_db) = app.llm_db.clone() else {
1513 log::warn!("failed to retrieve LLM database");
1514 return;
1515 };
1516
1517 let executor = app.executor.clone();
1518 executor.spawn_detached({
1519 let executor = executor.clone();
1520 async move {
1521 loop {
1522 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1523 .await
1524 .context("failed to sync LLM request usage to Stripe")
1525 .trace_err();
1526 executor
1527 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1528 .await;
1529 }
1530 }
1531 });
1532}
1533
1534async fn sync_model_request_usage_with_stripe(
1535 app: &Arc<AppState>,
1536 llm_db: &Arc<LlmDatabase>,
1537 stripe_billing: &Arc<StripeBilling>,
1538) -> anyhow::Result<()> {
1539 let staff_users = app.db.get_staff_users().await?;
1540 let staff_user_ids = staff_users
1541 .iter()
1542 .map(|user| user.id)
1543 .collect::<HashSet<UserId>>();
1544
1545 let usage_meters = llm_db
1546 .get_current_subscription_usage_meters(Utc::now())
1547 .await?;
1548 let usage_meters = usage_meters
1549 .into_iter()
1550 .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1551 .collect::<Vec<_>>();
1552 let user_ids = usage_meters
1553 .iter()
1554 .map(|(_, usage)| usage.user_id)
1555 .collect::<HashSet<UserId>>();
1556 let billing_subscriptions = app
1557 .db
1558 .get_active_zed_pro_billing_subscriptions(user_ids)
1559 .await?;
1560
1561 let claude_3_5_sonnet = stripe_billing
1562 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1563 .await?;
1564 let claude_3_7_sonnet = stripe_billing
1565 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1566 .await?;
1567 let claude_3_7_sonnet_max = stripe_billing
1568 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1569 .await?;
1570
1571 for (usage_meter, usage) in usage_meters {
1572 maybe!(async {
1573 let Some((billing_customer, billing_subscription)) =
1574 billing_subscriptions.get(&usage.user_id)
1575 else {
1576 bail!(
1577 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1578 usage.user_id
1579 );
1580 };
1581
1582 let stripe_customer_id = billing_customer
1583 .stripe_customer_id
1584 .parse::<stripe::CustomerId>()
1585 .context("failed to parse Stripe customer ID from database")?;
1586 let stripe_subscription_id = billing_subscription
1587 .stripe_subscription_id
1588 .parse::<stripe::SubscriptionId>()
1589 .context("failed to parse Stripe subscription ID from database")?;
1590
1591 let model = llm_db.model_by_id(usage_meter.model_id)?;
1592
1593 let (price, meter_event_name) = match model.name.as_str() {
1594 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1595 "claude-3-7-sonnet" => match usage_meter.mode {
1596 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1597 CompletionMode::Max => {
1598 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1599 }
1600 },
1601 model_name => {
1602 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1603 }
1604 };
1605
1606 stripe_billing
1607 .subscribe_to_price(&stripe_subscription_id, price)
1608 .await?;
1609 stripe_billing
1610 .bill_model_request_usage(
1611 &stripe_customer_id,
1612 meter_event_name,
1613 usage_meter.requests,
1614 )
1615 .await?;
1616
1617 Ok(())
1618 })
1619 .await
1620 .log_err();
1621 }
1622
1623 Ok(())
1624}