1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
30use crate::rpc::{ResultExt as _, Server};
31use crate::{AppState, Cents, Error, Result};
32use crate::{db::UserId, llm::db::LlmDatabase};
33use crate::{
34 db::{
35 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
36 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
37 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
38 },
39 stripe_billing::StripeBilling,
40};
41
42pub fn router() -> Router {
43 Router::new()
44 .route(
45 "/billing/preferences",
46 get(get_billing_preferences).put(update_billing_preferences),
47 )
48 .route(
49 "/billing/subscriptions",
50 get(list_billing_subscriptions).post(create_billing_subscription),
51 )
52 .route(
53 "/billing/subscriptions/manage",
54 post(manage_billing_subscription),
55 )
56 .route("/billing/monthly_spend", get(get_monthly_spend))
57 .route("/billing/usage", get(get_current_usage))
58}
59
60#[derive(Debug, Deserialize)]
61struct GetBillingPreferencesParams {
62 github_user_id: i32,
63}
64
65#[derive(Debug, Serialize)]
66struct BillingPreferencesResponse {
67 max_monthly_llm_usage_spending_in_cents: i32,
68 model_request_overages_enabled: bool,
69 model_request_overages_spend_limit_in_cents: i32,
70}
71
72async fn get_billing_preferences(
73 Extension(app): Extension<Arc<AppState>>,
74 Query(params): Query<GetBillingPreferencesParams>,
75) -> Result<Json<BillingPreferencesResponse>> {
76 let user = app
77 .db
78 .get_user_by_github_user_id(params.github_user_id)
79 .await?
80 .ok_or_else(|| anyhow!("user not found"))?;
81
82 let preferences = app.db.get_billing_preferences(user.id).await?;
83
84 Ok(Json(BillingPreferencesResponse {
85 max_monthly_llm_usage_spending_in_cents: preferences
86 .as_ref()
87 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
88 preferences.max_monthly_llm_usage_spending_in_cents
89 }),
90 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
91 preferences.model_request_overages_enabled
92 }),
93 model_request_overages_spend_limit_in_cents: preferences
94 .as_ref()
95 .map_or(0, |preferences| {
96 preferences.model_request_overages_spend_limit_in_cents
97 }),
98 }))
99}
100
101#[derive(Debug, Deserialize)]
102struct UpdateBillingPreferencesBody {
103 github_user_id: i32,
104 #[serde(default)]
105 max_monthly_llm_usage_spending_in_cents: i32,
106 #[serde(default)]
107 model_request_overages_enabled: bool,
108 #[serde(default)]
109 model_request_overages_spend_limit_in_cents: i32,
110}
111
112async fn update_billing_preferences(
113 Extension(app): Extension<Arc<AppState>>,
114 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
115 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
116) -> Result<Json<BillingPreferencesResponse>> {
117 let user = app
118 .db
119 .get_user_by_github_user_id(body.github_user_id)
120 .await?
121 .ok_or_else(|| anyhow!("user not found"))?;
122
123 let max_monthly_llm_usage_spending_in_cents =
124 body.max_monthly_llm_usage_spending_in_cents.max(0);
125 let model_request_overages_spend_limit_in_cents =
126 body.model_request_overages_spend_limit_in_cents.max(0);
127
128 let billing_preferences =
129 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
130 app.db
131 .update_billing_preferences(
132 user.id,
133 &UpdateBillingPreferencesParams {
134 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
135 max_monthly_llm_usage_spending_in_cents,
136 ),
137 model_request_overages_enabled: ActiveValue::set(
138 body.model_request_overages_enabled,
139 ),
140 model_request_overages_spend_limit_in_cents: ActiveValue::set(
141 model_request_overages_spend_limit_in_cents,
142 ),
143 },
144 )
145 .await?
146 } else {
147 app.db
148 .create_billing_preferences(
149 user.id,
150 &crate::db::CreateBillingPreferencesParams {
151 max_monthly_llm_usage_spending_in_cents,
152 model_request_overages_enabled: body.model_request_overages_enabled,
153 model_request_overages_spend_limit_in_cents,
154 },
155 )
156 .await?
157 };
158
159 SnowflakeRow::new(
160 "Billing Preferences Updated",
161 Some(user.metrics_id),
162 user.admin,
163 None,
164 json!({
165 "user_id": user.id,
166 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
167 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
168 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
169 }),
170 )
171 .write(&app.kinesis_client, &app.config.kinesis_stream)
172 .await
173 .log_err();
174
175 rpc_server.refresh_llm_tokens_for_user(user.id).await;
176
177 Ok(Json(BillingPreferencesResponse {
178 max_monthly_llm_usage_spending_in_cents: billing_preferences
179 .max_monthly_llm_usage_spending_in_cents,
180 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
181 model_request_overages_spend_limit_in_cents: billing_preferences
182 .model_request_overages_spend_limit_in_cents,
183 }))
184}
185
186#[derive(Debug, Deserialize)]
187struct ListBillingSubscriptionsParams {
188 github_user_id: i32,
189}
190
191#[derive(Debug, Serialize)]
192struct BillingSubscriptionJson {
193 id: BillingSubscriptionId,
194 name: String,
195 status: StripeSubscriptionStatus,
196 trial_end_at: Option<String>,
197 cancel_at: Option<String>,
198 /// Whether this subscription can be canceled.
199 is_cancelable: bool,
200}
201
202#[derive(Debug, Serialize)]
203struct ListBillingSubscriptionsResponse {
204 subscriptions: Vec<BillingSubscriptionJson>,
205}
206
207async fn list_billing_subscriptions(
208 Extension(app): Extension<Arc<AppState>>,
209 Query(params): Query<ListBillingSubscriptionsParams>,
210) -> Result<Json<ListBillingSubscriptionsResponse>> {
211 let user = app
212 .db
213 .get_user_by_github_user_id(params.github_user_id)
214 .await?
215 .ok_or_else(|| anyhow!("user not found"))?;
216
217 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
218
219 Ok(Json(ListBillingSubscriptionsResponse {
220 subscriptions: subscriptions
221 .into_iter()
222 .map(|subscription| BillingSubscriptionJson {
223 id: subscription.id,
224 name: match subscription.kind {
225 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
226 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
227 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
228 None => "Zed LLM Usage".to_string(),
229 },
230 status: subscription.stripe_subscription_status,
231 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
232 maybe!({
233 let end_at = subscription.stripe_current_period_end?;
234 let end_at = DateTime::from_timestamp(end_at, 0)?;
235
236 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
237 })
238 } else {
239 None
240 },
241 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
242 cancel_at
243 .and_utc()
244 .to_rfc3339_opts(SecondsFormat::Millis, true)
245 }),
246 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
247 && subscription.stripe_cancel_at.is_none(),
248 })
249 .collect(),
250 }))
251}
252
253#[derive(Debug, Clone, Copy, Deserialize)]
254#[serde(rename_all = "snake_case")]
255enum ProductCode {
256 ZedPro,
257 ZedProTrial,
258}
259
260#[derive(Debug, Deserialize)]
261struct CreateBillingSubscriptionBody {
262 github_user_id: i32,
263 product: Option<ProductCode>,
264}
265
266#[derive(Debug, Serialize)]
267struct CreateBillingSubscriptionResponse {
268 checkout_session_url: String,
269}
270
271/// Initiates a Stripe Checkout session for creating a billing subscription.
272async fn create_billing_subscription(
273 Extension(app): Extension<Arc<AppState>>,
274 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
275) -> Result<Json<CreateBillingSubscriptionResponse>> {
276 let user = app
277 .db
278 .get_user_by_github_user_id(body.github_user_id)
279 .await?
280 .ok_or_else(|| anyhow!("user not found"))?;
281
282 let Some(stripe_client) = app.stripe_client.clone() else {
283 log::error!("failed to retrieve Stripe client");
284 Err(Error::http(
285 StatusCode::NOT_IMPLEMENTED,
286 "not supported".into(),
287 ))?
288 };
289 let Some(stripe_billing) = app.stripe_billing.clone() else {
290 log::error!("failed to retrieve Stripe billing object");
291 Err(Error::http(
292 StatusCode::NOT_IMPLEMENTED,
293 "not supported".into(),
294 ))?
295 };
296 let Some(llm_db) = app.llm_db.clone() else {
297 log::error!("failed to retrieve LLM database");
298 Err(Error::http(
299 StatusCode::NOT_IMPLEMENTED,
300 "not supported".into(),
301 ))?
302 };
303
304 if app.db.has_active_billing_subscription(user.id).await? {
305 return Err(Error::http(
306 StatusCode::CONFLICT,
307 "user already has an active subscription".into(),
308 ));
309 }
310
311 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
312 if let Some(existing_billing_customer) = &existing_billing_customer {
313 if existing_billing_customer.has_overdue_invoices {
314 return Err(Error::http(
315 StatusCode::PAYMENT_REQUIRED,
316 "user has overdue invoices".into(),
317 ));
318 }
319 }
320
321 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
322 CustomerId::from_str(&existing_customer.stripe_customer_id)
323 .context("failed to parse customer ID")?
324 } else {
325 let existing_customer = if let Some(email) = user.email_address.as_deref() {
326 let customers = Customer::list(
327 &stripe_client,
328 &stripe::ListCustomers {
329 email: Some(email),
330 ..Default::default()
331 },
332 )
333 .await?;
334
335 customers.data.first().cloned()
336 } else {
337 None
338 };
339
340 if let Some(existing_customer) = existing_customer {
341 existing_customer.id
342 } else {
343 let customer = Customer::create(
344 &stripe_client,
345 CreateCustomer {
346 email: user.email_address.as_deref(),
347 ..Default::default()
348 },
349 )
350 .await?;
351
352 customer.id
353 }
354 };
355
356 let success_url = format!(
357 "{}/account?checkout_complete=1",
358 app.config.zed_dot_dev_url()
359 );
360
361 let checkout_session_url = match body.product {
362 Some(ProductCode::ZedPro) => {
363 stripe_billing
364 .checkout_with_price(
365 app.config.zed_pro_price_id()?,
366 customer_id,
367 &user.github_login,
368 &success_url,
369 )
370 .await?
371 }
372 Some(ProductCode::ZedProTrial) => {
373 if let Some(existing_billing_customer) = &existing_billing_customer {
374 if existing_billing_customer.trial_started_at.is_some() {
375 return Err(Error::http(
376 StatusCode::FORBIDDEN,
377 "user already used free trial".into(),
378 ));
379 }
380 }
381
382 stripe_billing
383 .checkout_with_zed_pro_trial(
384 app.config.zed_pro_price_id()?,
385 customer_id,
386 &user.github_login,
387 &success_url,
388 )
389 .await?
390 }
391 None => {
392 let default_model = llm_db.model(
393 zed_llm_client::LanguageModelProvider::Anthropic,
394 "claude-3-7-sonnet",
395 )?;
396 let stripe_model = stripe_billing.register_model(default_model).await?;
397 stripe_billing
398 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
399 .await?
400 }
401 };
402
403 Ok(Json(CreateBillingSubscriptionResponse {
404 checkout_session_url,
405 }))
406}
407
408#[derive(Debug, PartialEq, Deserialize)]
409#[serde(rename_all = "snake_case")]
410enum ManageSubscriptionIntent {
411 /// The user intends to manage their subscription.
412 ///
413 /// This will open the Stripe billing portal without putting the user in a specific flow.
414 ManageSubscription,
415 /// The user intends to upgrade to Zed Pro.
416 UpgradeToPro,
417 /// The user intends to cancel their subscription.
418 Cancel,
419 /// The user intends to stop the cancellation of their subscription.
420 StopCancellation,
421}
422
423#[derive(Debug, Deserialize)]
424struct ManageBillingSubscriptionBody {
425 github_user_id: i32,
426 intent: ManageSubscriptionIntent,
427 /// The ID of the subscription to manage.
428 subscription_id: BillingSubscriptionId,
429}
430
431#[derive(Debug, Serialize)]
432struct ManageBillingSubscriptionResponse {
433 billing_portal_session_url: Option<String>,
434}
435
436/// Initiates a Stripe customer portal session for managing a billing subscription.
437async fn manage_billing_subscription(
438 Extension(app): Extension<Arc<AppState>>,
439 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
440) -> Result<Json<ManageBillingSubscriptionResponse>> {
441 let user = app
442 .db
443 .get_user_by_github_user_id(body.github_user_id)
444 .await?
445 .ok_or_else(|| anyhow!("user not found"))?;
446
447 let Some(stripe_client) = app.stripe_client.clone() else {
448 log::error!("failed to retrieve Stripe client");
449 Err(Error::http(
450 StatusCode::NOT_IMPLEMENTED,
451 "not supported".into(),
452 ))?
453 };
454
455 let customer = app
456 .db
457 .get_billing_customer_by_user_id(user.id)
458 .await?
459 .ok_or_else(|| anyhow!("billing customer not found"))?;
460 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
461 .context("failed to parse customer ID")?;
462
463 let subscription = app
464 .db
465 .get_billing_subscription_by_id(body.subscription_id)
466 .await?
467 .ok_or_else(|| anyhow!("subscription not found"))?;
468 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
469 .context("failed to parse subscription ID")?;
470
471 if body.intent == ManageSubscriptionIntent::StopCancellation {
472 let updated_stripe_subscription = Subscription::update(
473 &stripe_client,
474 &subscription_id,
475 stripe::UpdateSubscription {
476 cancel_at_period_end: Some(false),
477 ..Default::default()
478 },
479 )
480 .await?;
481
482 app.db
483 .update_billing_subscription(
484 subscription.id,
485 &UpdateBillingSubscriptionParams {
486 stripe_cancel_at: ActiveValue::set(
487 updated_stripe_subscription
488 .cancel_at
489 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
490 .map(|time| time.naive_utc()),
491 ),
492 ..Default::default()
493 },
494 )
495 .await?;
496
497 return Ok(Json(ManageBillingSubscriptionResponse {
498 billing_portal_session_url: None,
499 }));
500 }
501
502 let flow = match body.intent {
503 ManageSubscriptionIntent::ManageSubscription => None,
504 ManageSubscriptionIntent::UpgradeToPro => {
505 let zed_pro_price_id = app.config.zed_pro_price_id()?;
506 let zed_free_price_id = app.config.zed_free_price_id()?;
507
508 let stripe_subscription =
509 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
510
511 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
512 && stripe_subscription.items.data.iter().any(|item| {
513 item.price
514 .as_ref()
515 .map_or(false, |price| price.id == zed_pro_price_id)
516 });
517 if is_on_zed_pro_trial {
518 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
519 Subscription::update(
520 &stripe_client,
521 &stripe_subscription.id,
522 stripe::UpdateSubscription {
523 trial_end: Some(stripe::Scheduled::now()),
524 ..Default::default()
525 },
526 )
527 .await?;
528
529 return Ok(Json(ManageBillingSubscriptionResponse {
530 billing_portal_session_url: None,
531 }));
532 }
533
534 let subscription_item_to_update = stripe_subscription
535 .items
536 .data
537 .iter()
538 .find_map(|item| {
539 let price = item.price.as_ref()?;
540
541 if price.id == zed_free_price_id {
542 Some(item.id.clone())
543 } else {
544 None
545 }
546 })
547 .ok_or_else(|| anyhow!("No subscription item to update"))?;
548
549 Some(CreateBillingPortalSessionFlowData {
550 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
551 subscription_update_confirm: Some(
552 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
553 subscription: subscription.stripe_subscription_id,
554 items: vec![
555 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
556 id: subscription_item_to_update.to_string(),
557 price: Some(zed_pro_price_id.to_string()),
558 quantity: Some(1),
559 },
560 ],
561 discounts: None,
562 },
563 ),
564 ..Default::default()
565 })
566 }
567 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
568 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
569 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
570 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
571 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
572 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
573 }),
574 ..Default::default()
575 }),
576 subscription_cancel: Some(
577 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
578 subscription: subscription.stripe_subscription_id,
579 retention: None,
580 },
581 ),
582 ..Default::default()
583 }),
584 ManageSubscriptionIntent::StopCancellation => unreachable!(),
585 };
586
587 let mut params = CreateBillingPortalSession::new(customer_id);
588 params.flow_data = flow;
589 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
590 params.return_url = Some(&return_url);
591
592 let session = BillingPortalSession::create(&stripe_client, params).await?;
593
594 Ok(Json(ManageBillingSubscriptionResponse {
595 billing_portal_session_url: Some(session.url),
596 }))
597}
598
599/// The amount of time we wait in between each poll of Stripe events.
600///
601/// This value should strike a balance between:
602/// 1. Being short enough that we update quickly when something in Stripe changes
603/// 2. Being long enough that we don't eat into our rate limits.
604///
605/// As a point of reference, the Sequin folks say they have this at **500ms**:
606///
607/// > We poll the Stripe /events endpoint every 500ms per account
608/// >
609/// > — https://blog.sequinstream.com/events-not-webhooks/
610const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
611
612/// The maximum number of events to return per page.
613///
614/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
615///
616/// > Limit can range between 1 and 100, and the default is 10.
617const EVENTS_LIMIT_PER_PAGE: u64 = 100;
618
619/// The number of pages consisting entirely of already-processed events that we
620/// will see before we stop retrieving events.
621///
622/// This is used to prevent over-fetching the Stripe events API for events we've
623/// already seen and processed.
624const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
625
626/// Polls the Stripe events API periodically to reconcile the records in our
627/// database with the data in Stripe.
628pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
629 let Some(stripe_client) = app.stripe_client.clone() else {
630 log::warn!("failed to retrieve Stripe client");
631 return;
632 };
633
634 let executor = app.executor.clone();
635 executor.spawn_detached({
636 let executor = executor.clone();
637 async move {
638 loop {
639 poll_stripe_events(&app, &rpc_server, &stripe_client)
640 .await
641 .log_err();
642
643 executor.sleep(POLL_EVENTS_INTERVAL).await;
644 }
645 }
646 });
647}
648
649async fn poll_stripe_events(
650 app: &Arc<AppState>,
651 rpc_server: &Arc<Server>,
652 stripe_client: &stripe::Client,
653) -> anyhow::Result<()> {
654 fn event_type_to_string(event_type: EventType) -> String {
655 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
656 // so we need to unquote it.
657 event_type.to_string().trim_matches('"').to_string()
658 }
659
660 let event_types = [
661 EventType::CustomerCreated,
662 EventType::CustomerUpdated,
663 EventType::CustomerSubscriptionCreated,
664 EventType::CustomerSubscriptionUpdated,
665 EventType::CustomerSubscriptionPaused,
666 EventType::CustomerSubscriptionResumed,
667 EventType::CustomerSubscriptionDeleted,
668 ]
669 .into_iter()
670 .map(event_type_to_string)
671 .collect::<Vec<_>>();
672
673 let mut pages_of_already_processed_events = 0;
674 let mut unprocessed_events = Vec::new();
675
676 log::info!(
677 "Stripe events: starting retrieval for {}",
678 event_types.join(", ")
679 );
680 let mut params = ListEvents::new();
681 params.types = Some(event_types.clone());
682 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
683
684 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
685 .await?
686 .paginate(params);
687
688 loop {
689 let processed_event_ids = {
690 let event_ids = event_pages
691 .page
692 .data
693 .iter()
694 .map(|event| event.id.as_str())
695 .collect::<Vec<_>>();
696 app.db
697 .get_processed_stripe_events_by_event_ids(&event_ids)
698 .await?
699 .into_iter()
700 .map(|event| event.stripe_event_id)
701 .collect::<Vec<_>>()
702 };
703
704 let mut processed_events_in_page = 0;
705 let events_in_page = event_pages.page.data.len();
706 for event in &event_pages.page.data {
707 if processed_event_ids.contains(&event.id.to_string()) {
708 processed_events_in_page += 1;
709 log::debug!("Stripe events: already processed '{}', skipping", event.id);
710 } else {
711 unprocessed_events.push(event.clone());
712 }
713 }
714
715 if processed_events_in_page == events_in_page {
716 pages_of_already_processed_events += 1;
717 }
718
719 if event_pages.page.has_more {
720 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
721 {
722 log::info!(
723 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
724 );
725 break;
726 } else {
727 log::info!("Stripe events: retrieving next page");
728 event_pages = event_pages.next(&stripe_client).await?;
729 }
730 } else {
731 break;
732 }
733 }
734
735 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
736
737 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
738 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
739
740 for event in unprocessed_events {
741 let event_id = event.id.clone();
742 let processed_event_params = CreateProcessedStripeEventParams {
743 stripe_event_id: event.id.to_string(),
744 stripe_event_type: event_type_to_string(event.type_),
745 stripe_event_created_timestamp: event.created,
746 };
747
748 // If the event has happened too far in the past, we don't want to
749 // process it and risk overwriting other more-recent updates.
750 //
751 // 1 day was chosen arbitrarily. This could be made longer or shorter.
752 let one_day = Duration::from_secs(24 * 60 * 60);
753 let a_day_ago = Utc::now() - one_day;
754 if a_day_ago.timestamp() > event.created {
755 log::info!(
756 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
757 event_id
758 );
759 app.db
760 .create_processed_stripe_event(&processed_event_params)
761 .await?;
762
763 return Ok(());
764 }
765
766 let process_result = match event.type_ {
767 EventType::CustomerCreated | EventType::CustomerUpdated => {
768 handle_customer_event(app, stripe_client, event).await
769 }
770 EventType::CustomerSubscriptionCreated
771 | EventType::CustomerSubscriptionUpdated
772 | EventType::CustomerSubscriptionPaused
773 | EventType::CustomerSubscriptionResumed
774 | EventType::CustomerSubscriptionDeleted => {
775 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
776 }
777 _ => Ok(()),
778 };
779
780 if let Some(()) = process_result
781 .with_context(|| format!("failed to process event {event_id} successfully"))
782 .log_err()
783 {
784 app.db
785 .create_processed_stripe_event(&processed_event_params)
786 .await?;
787 }
788 }
789
790 Ok(())
791}
792
793async fn handle_customer_event(
794 app: &Arc<AppState>,
795 _stripe_client: &stripe::Client,
796 event: stripe::Event,
797) -> anyhow::Result<()> {
798 let EventObject::Customer(customer) = event.data.object else {
799 bail!("unexpected event payload for {}", event.id);
800 };
801
802 log::info!("handling Stripe {} event: {}", event.type_, event.id);
803
804 let Some(email) = customer.email else {
805 log::info!("Stripe customer has no email: skipping");
806 return Ok(());
807 };
808
809 let Some(user) = app.db.get_user_by_email(&email).await? else {
810 log::info!("no user found for email: skipping");
811 return Ok(());
812 };
813
814 if let Some(existing_customer) = app
815 .db
816 .get_billing_customer_by_stripe_customer_id(&customer.id)
817 .await?
818 {
819 app.db
820 .update_billing_customer(
821 existing_customer.id,
822 &UpdateBillingCustomerParams {
823 // For now we just leave the information as-is, as it is not
824 // likely to change.
825 ..Default::default()
826 },
827 )
828 .await?;
829 } else {
830 app.db
831 .create_billing_customer(&CreateBillingCustomerParams {
832 user_id: user.id,
833 stripe_customer_id: customer.id.to_string(),
834 })
835 .await?;
836 }
837
838 Ok(())
839}
840
841async fn handle_customer_subscription_event(
842 app: &Arc<AppState>,
843 rpc_server: &Arc<Server>,
844 stripe_client: &stripe::Client,
845 event: stripe::Event,
846) -> anyhow::Result<()> {
847 let EventObject::Subscription(subscription) = event.data.object else {
848 bail!("unexpected event payload for {}", event.id);
849 };
850
851 log::info!("handling Stripe {} event: {}", event.type_, event.id);
852
853 let subscription_kind = maybe!({
854 let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
855 let zed_free_price_id = app.config.zed_free_price_id().ok()?;
856
857 subscription.items.data.iter().find_map(|item| {
858 let price = item.price.as_ref()?;
859
860 if price.id == zed_pro_price_id {
861 Some(if subscription.status == SubscriptionStatus::Trialing {
862 SubscriptionKind::ZedProTrial
863 } else {
864 SubscriptionKind::ZedPro
865 })
866 } else if price.id == zed_free_price_id {
867 Some(SubscriptionKind::ZedFree)
868 } else {
869 None
870 }
871 })
872 });
873
874 let billing_customer =
875 find_or_create_billing_customer(app, stripe_client, subscription.customer)
876 .await?
877 .ok_or_else(|| anyhow!("billing customer not found"))?;
878
879 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
880 if subscription.status == SubscriptionStatus::Trialing {
881 let current_period_start =
882 DateTime::from_timestamp(subscription.current_period_start, 0)
883 .ok_or_else(|| anyhow!("No trial subscription period start"))?;
884
885 app.db
886 .update_billing_customer(
887 billing_customer.id,
888 &UpdateBillingCustomerParams {
889 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
890 ..Default::default()
891 },
892 )
893 .await?;
894 }
895 }
896
897 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
898 && subscription
899 .cancellation_details
900 .as_ref()
901 .and_then(|details| details.reason)
902 .map_or(false, |reason| {
903 reason == CancellationDetailsReason::PaymentFailed
904 });
905
906 if was_canceled_due_to_payment_failure {
907 app.db
908 .update_billing_customer(
909 billing_customer.id,
910 &UpdateBillingCustomerParams {
911 has_overdue_invoices: ActiveValue::set(true),
912 ..Default::default()
913 },
914 )
915 .await?;
916 }
917
918 if let Some(existing_subscription) = app
919 .db
920 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
921 .await?
922 {
923 let llm_db = app
924 .llm_db
925 .clone()
926 .ok_or_else(|| anyhow!("LLM DB not initialized"))?;
927
928 let new_period_start_at =
929 chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
930 .ok_or_else(|| anyhow!("No subscription period start"))?;
931 let new_period_end_at =
932 chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
933 .ok_or_else(|| anyhow!("No subscription period end"))?;
934
935 llm_db
936 .transfer_existing_subscription_usage(
937 billing_customer.user_id,
938 &existing_subscription,
939 subscription_kind,
940 new_period_start_at,
941 new_period_end_at,
942 )
943 .await?;
944
945 app.db
946 .update_billing_subscription(
947 existing_subscription.id,
948 &UpdateBillingSubscriptionParams {
949 billing_customer_id: ActiveValue::set(billing_customer.id),
950 kind: ActiveValue::set(subscription_kind),
951 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
952 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
953 stripe_cancel_at: ActiveValue::set(
954 subscription
955 .cancel_at
956 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
957 .map(|time| time.naive_utc()),
958 ),
959 stripe_cancellation_reason: ActiveValue::set(
960 subscription
961 .cancellation_details
962 .and_then(|details| details.reason)
963 .map(|reason| reason.into()),
964 ),
965 stripe_current_period_start: ActiveValue::set(Some(
966 subscription.current_period_start,
967 )),
968 stripe_current_period_end: ActiveValue::set(Some(
969 subscription.current_period_end,
970 )),
971 },
972 )
973 .await?;
974 } else {
975 // If the user already has an active billing subscription, ignore the
976 // event and return an `Ok` to signal that it was processed
977 // successfully.
978 //
979 // There is the possibility that this could cause us to not create a
980 // subscription in the following scenario:
981 //
982 // 1. User has an active subscription A
983 // 2. User cancels subscription A
984 // 3. User creates a new subscription B
985 // 4. We process the new subscription B before the cancellation of subscription A
986 // 5. User ends up with no subscriptions
987 //
988 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
989 if app
990 .db
991 .has_active_billing_subscription(billing_customer.user_id)
992 .await?
993 {
994 log::info!(
995 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
996 user_id = billing_customer.user_id,
997 subscription_id = subscription.id
998 );
999 return Ok(());
1000 }
1001
1002 app.db
1003 .create_billing_subscription(&CreateBillingSubscriptionParams {
1004 billing_customer_id: billing_customer.id,
1005 kind: subscription_kind,
1006 stripe_subscription_id: subscription.id.to_string(),
1007 stripe_subscription_status: subscription.status.into(),
1008 stripe_cancellation_reason: subscription
1009 .cancellation_details
1010 .and_then(|details| details.reason)
1011 .map(|reason| reason.into()),
1012 stripe_current_period_start: Some(subscription.current_period_start),
1013 stripe_current_period_end: Some(subscription.current_period_end),
1014 })
1015 .await?;
1016 }
1017
1018 // When the user's subscription changes, we want to refresh their LLM tokens
1019 // to either grant/revoke access.
1020 rpc_server
1021 .refresh_llm_tokens_for_user(billing_customer.user_id)
1022 .await;
1023
1024 Ok(())
1025}
1026
/// Query parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID of the user whose spending is being requested.
    github_user_id: i32,
}
1031
/// Response body returned by `get_monthly_spend`. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// Portion of this month's spending covered by the free-tier allowance.
    monthly_free_tier_spend_in_cents: u32,
    /// Free-tier allowance applied to this user (their custom allowance, or the default limit).
    monthly_free_tier_allowance_in_cents: u32,
    /// Spending beyond the free-tier allowance.
    monthly_spend_in_cents: u32,
}
1038
1039async fn get_monthly_spend(
1040 Extension(app): Extension<Arc<AppState>>,
1041 Query(params): Query<GetMonthlySpendParams>,
1042) -> Result<Json<GetMonthlySpendResponse>> {
1043 let user = app
1044 .db
1045 .get_user_by_github_user_id(params.github_user_id)
1046 .await?
1047 .ok_or_else(|| anyhow!("user not found"))?;
1048
1049 let Some(llm_db) = app.llm_db.clone() else {
1050 return Err(Error::http(
1051 StatusCode::NOT_IMPLEMENTED,
1052 "LLM database not available".into(),
1053 ));
1054 };
1055
1056 let free_tier = user
1057 .custom_llm_monthly_allowance_in_cents
1058 .map(|allowance| Cents(allowance as u32))
1059 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1060
1061 let spending_for_month = llm_db
1062 .get_user_spending_for_month(user.id, Utc::now())
1063 .await?;
1064
1065 let free_tier_spend = Cents::min(spending_for_month, free_tier);
1066 let monthly_spend = spending_for_month.saturating_sub(free_tier);
1067
1068 Ok(Json(GetMonthlySpendResponse {
1069 monthly_free_tier_spend_in_cents: free_tier_spend.0,
1070 monthly_free_tier_allowance_in_cents: free_tier.0,
1071 monthly_spend_in_cents: monthly_spend.0,
1072 }))
1073}
1074
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID of the user whose usage is being requested.
    github_user_id: i32,
}
1079
/// Usage counters for one kind of metered activity, as reported by `get_current_usage`.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Number of uses recorded for the subscription's current period.
    pub used: i32,
    /// Maximum uses allowed by the plan; `None` when the plan is unlimited.
    pub limit: Option<i32>,
    /// Uses left before reaching `limit`, floored at zero; `None` when the plan is unlimited.
    pub remaining: Option<i32>,
}
1086
/// Response body returned by `get_current_usage`.
#[derive(Debug, Serialize)]
struct GetCurrentUsageResponse {
    /// Counters for LLM model requests.
    pub model_requests: UsageCounts,
    /// Counters for edit predictions.
    pub edit_predictions: UsageCounts,
}
1092
1093async fn get_current_usage(
1094 Extension(app): Extension<Arc<AppState>>,
1095 Query(params): Query<GetCurrentUsageParams>,
1096) -> Result<Json<GetCurrentUsageResponse>> {
1097 let user = app
1098 .db
1099 .get_user_by_github_user_id(params.github_user_id)
1100 .await?
1101 .ok_or_else(|| anyhow!("user not found"))?;
1102
1103 let Some(llm_db) = app.llm_db.clone() else {
1104 return Err(Error::http(
1105 StatusCode::NOT_IMPLEMENTED,
1106 "LLM database not available".into(),
1107 ));
1108 };
1109
1110 let empty_usage = GetCurrentUsageResponse {
1111 model_requests: UsageCounts {
1112 used: 0,
1113 limit: Some(0),
1114 remaining: Some(0),
1115 },
1116 edit_predictions: UsageCounts {
1117 used: 0,
1118 limit: Some(0),
1119 remaining: Some(0),
1120 },
1121 };
1122
1123 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1124 return Ok(Json(empty_usage));
1125 };
1126
1127 let subscription_period = maybe!({
1128 let period_start_at = subscription.current_period_start_at()?;
1129 let period_end_at = subscription.current_period_end_at()?;
1130
1131 Some((period_start_at, period_end_at))
1132 });
1133
1134 let Some((period_start_at, period_end_at)) = subscription_period else {
1135 return Ok(Json(empty_usage));
1136 };
1137
1138 let usage = llm_db
1139 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1140 .await?;
1141 let Some(usage) = usage else {
1142 return Ok(Json(empty_usage));
1143 };
1144
1145 let plan = match usage.plan {
1146 SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1147 SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1148 SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1149 };
1150
1151 let model_requests_limit = match plan.model_requests_limit() {
1152 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1153 zed_llm_client::UsageLimit::Unlimited => None,
1154 };
1155 let edit_prediction_limit = match plan.edit_predictions_limit() {
1156 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1157 zed_llm_client::UsageLimit::Unlimited => None,
1158 };
1159
1160 Ok(Json(GetCurrentUsageResponse {
1161 model_requests: UsageCounts {
1162 used: usage.model_requests,
1163 limit: model_requests_limit,
1164 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1165 },
1166 edit_predictions: UsageCounts {
1167 used: usage.edit_predictions,
1168 limit: edit_prediction_limit,
1169 remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1170 },
1171 }))
1172}
1173
1174impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1175 fn from(value: SubscriptionStatus) -> Self {
1176 match value {
1177 SubscriptionStatus::Incomplete => Self::Incomplete,
1178 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1179 SubscriptionStatus::Trialing => Self::Trialing,
1180 SubscriptionStatus::Active => Self::Active,
1181 SubscriptionStatus::PastDue => Self::PastDue,
1182 SubscriptionStatus::Canceled => Self::Canceled,
1183 SubscriptionStatus::Unpaid => Self::Unpaid,
1184 SubscriptionStatus::Paused => Self::Paused,
1185 }
1186 }
1187}
1188
1189impl From<CancellationDetailsReason> for StripeCancellationReason {
1190 fn from(value: CancellationDetailsReason) -> Self {
1191 match value {
1192 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1193 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1194 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1195 }
1196 }
1197}
1198
1199/// Finds or creates a billing customer using the provided customer.
1200async fn find_or_create_billing_customer(
1201 app: &Arc<AppState>,
1202 stripe_client: &stripe::Client,
1203 customer_or_id: Expandable<Customer>,
1204) -> anyhow::Result<Option<billing_customer::Model>> {
1205 let customer_id = match &customer_or_id {
1206 Expandable::Id(id) => id,
1207 Expandable::Object(customer) => customer.id.as_ref(),
1208 };
1209
1210 // If we already have a billing customer record associated with the Stripe customer,
1211 // there's nothing more we need to do.
1212 if let Some(billing_customer) = app
1213 .db
1214 .get_billing_customer_by_stripe_customer_id(customer_id)
1215 .await?
1216 {
1217 return Ok(Some(billing_customer));
1218 }
1219
1220 // If all we have is a customer ID, resolve it to a full customer record by
1221 // hitting the Stripe API.
1222 let customer = match customer_or_id {
1223 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1224 Expandable::Object(customer) => *customer,
1225 };
1226
1227 let Some(email) = customer.email else {
1228 return Ok(None);
1229 };
1230
1231 let Some(user) = app.db.get_user_by_email(&email).await? else {
1232 return Ok(None);
1233 };
1234
1235 let billing_customer = app
1236 .db
1237 .create_billing_customer(&CreateBillingCustomerParams {
1238 user_id: user.id,
1239 stripe_customer_id: customer.id.to_string(),
1240 })
1241 .await?;
1242
1243 Ok(Some(billing_customer))
1244}
1245
/// Delay between successive runs of the background LLM-usage-to-Stripe sync loop.
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1247
1248pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
1249 let Some(stripe_billing) = app.stripe_billing.clone() else {
1250 log::warn!("failed to retrieve Stripe billing object");
1251 return;
1252 };
1253 let Some(llm_db) = app.llm_db.clone() else {
1254 log::warn!("failed to retrieve LLM database");
1255 return;
1256 };
1257
1258 let executor = app.executor.clone();
1259 executor.spawn_detached({
1260 let executor = executor.clone();
1261 async move {
1262 loop {
1263 sync_with_stripe(&app, &llm_db, &stripe_billing)
1264 .await
1265 .context("failed to sync LLM usage to Stripe")
1266 .trace_err();
1267 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
1268 }
1269 }
1270 });
1271}
1272
1273async fn sync_with_stripe(
1274 app: &Arc<AppState>,
1275 llm_db: &Arc<LlmDatabase>,
1276 stripe_billing: &Arc<StripeBilling>,
1277) -> anyhow::Result<()> {
1278 let events = llm_db.get_billing_events().await?;
1279 let user_ids = events
1280 .iter()
1281 .map(|(event, _)| event.user_id)
1282 .collect::<HashSet<UserId>>();
1283 let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1284
1285 for (event, model) in events {
1286 let Some((stripe_db_customer, stripe_db_subscription)) =
1287 stripe_subscriptions.get(&event.user_id)
1288 else {
1289 tracing::warn!(
1290 user_id = event.user_id.0,
1291 "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1292 );
1293 continue;
1294 };
1295 let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1296 .stripe_subscription_id
1297 .parse()
1298 .context("failed to parse stripe subscription id from db")?;
1299 let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1300 .stripe_customer_id
1301 .parse()
1302 .context("failed to parse stripe customer id from db")?;
1303
1304 let stripe_model = stripe_billing.register_model(&model).await?;
1305 stripe_billing
1306 .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1307 .await?;
1308 stripe_billing
1309 .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
1310 .await?;
1311 llm_db.consume_billing_event(event.id).await?;
1312 }
1313
1314 Ok(())
1315}