1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
30use crate::rpc::{ResultExt as _, Server};
31use crate::{AppState, Cents, Error, Result};
32use crate::{db::UserId, llm::db::LlmDatabase};
33use crate::{
34 db::{
35 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
36 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
37 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
38 },
39 stripe_billing::StripeBilling,
40};
41
42pub fn router() -> Router {
43 Router::new()
44 .route(
45 "/billing/preferences",
46 get(get_billing_preferences).put(update_billing_preferences),
47 )
48 .route(
49 "/billing/subscriptions",
50 get(list_billing_subscriptions).post(create_billing_subscription),
51 )
52 .route(
53 "/billing/subscriptions/manage",
54 post(manage_billing_subscription),
55 )
56 .route("/billing/monthly_spend", get(get_monthly_spend))
57 .route("/billing/usage", get(get_current_usage))
58}
59
60#[derive(Debug, Deserialize)]
61struct GetBillingPreferencesParams {
62 github_user_id: i32,
63}
64
65#[derive(Debug, Serialize)]
66struct BillingPreferencesResponse {
67 max_monthly_llm_usage_spending_in_cents: i32,
68 model_request_overages_enabled: bool,
69 model_request_overages_spend_limit_in_cents: i32,
70}
71
72async fn get_billing_preferences(
73 Extension(app): Extension<Arc<AppState>>,
74 Query(params): Query<GetBillingPreferencesParams>,
75) -> Result<Json<BillingPreferencesResponse>> {
76 let user = app
77 .db
78 .get_user_by_github_user_id(params.github_user_id)
79 .await?
80 .ok_or_else(|| anyhow!("user not found"))?;
81
82 let preferences = app.db.get_billing_preferences(user.id).await?;
83
84 Ok(Json(BillingPreferencesResponse {
85 max_monthly_llm_usage_spending_in_cents: preferences
86 .as_ref()
87 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
88 preferences.max_monthly_llm_usage_spending_in_cents
89 }),
90 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
91 preferences.model_request_overages_enabled
92 }),
93 model_request_overages_spend_limit_in_cents: preferences
94 .as_ref()
95 .map_or(0, |preferences| {
96 preferences.model_request_overages_spend_limit_in_cents
97 }),
98 }))
99}
100
101#[derive(Debug, Deserialize)]
102struct UpdateBillingPreferencesBody {
103 github_user_id: i32,
104 #[serde(default)]
105 max_monthly_llm_usage_spending_in_cents: i32,
106 #[serde(default)]
107 model_request_overages_enabled: bool,
108 #[serde(default)]
109 model_request_overages_spend_limit_in_cents: i32,
110}
111
112async fn update_billing_preferences(
113 Extension(app): Extension<Arc<AppState>>,
114 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
115 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
116) -> Result<Json<BillingPreferencesResponse>> {
117 let user = app
118 .db
119 .get_user_by_github_user_id(body.github_user_id)
120 .await?
121 .ok_or_else(|| anyhow!("user not found"))?;
122
123 let max_monthly_llm_usage_spending_in_cents =
124 body.max_monthly_llm_usage_spending_in_cents.max(0);
125 let model_request_overages_spend_limit_in_cents =
126 body.model_request_overages_spend_limit_in_cents.max(0);
127
128 let billing_preferences =
129 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
130 app.db
131 .update_billing_preferences(
132 user.id,
133 &UpdateBillingPreferencesParams {
134 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
135 max_monthly_llm_usage_spending_in_cents,
136 ),
137 model_request_overages_enabled: ActiveValue::set(
138 body.model_request_overages_enabled,
139 ),
140 model_request_overages_spend_limit_in_cents: ActiveValue::set(
141 model_request_overages_spend_limit_in_cents,
142 ),
143 },
144 )
145 .await?
146 } else {
147 app.db
148 .create_billing_preferences(
149 user.id,
150 &crate::db::CreateBillingPreferencesParams {
151 max_monthly_llm_usage_spending_in_cents,
152 model_request_overages_enabled: body.model_request_overages_enabled,
153 model_request_overages_spend_limit_in_cents,
154 },
155 )
156 .await?
157 };
158
159 SnowflakeRow::new(
160 "Billing Preferences Updated",
161 Some(user.metrics_id),
162 user.admin,
163 None,
164 json!({
165 "user_id": user.id,
166 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
167 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
168 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
169 }),
170 )
171 .write(&app.kinesis_client, &app.config.kinesis_stream)
172 .await
173 .log_err();
174
175 rpc_server.refresh_llm_tokens_for_user(user.id).await;
176
177 Ok(Json(BillingPreferencesResponse {
178 max_monthly_llm_usage_spending_in_cents: billing_preferences
179 .max_monthly_llm_usage_spending_in_cents,
180 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
181 model_request_overages_spend_limit_in_cents: billing_preferences
182 .model_request_overages_spend_limit_in_cents,
183 }))
184}
185
186#[derive(Debug, Deserialize)]
187struct ListBillingSubscriptionsParams {
188 github_user_id: i32,
189}
190
191#[derive(Debug, Serialize)]
192struct BillingSubscriptionJson {
193 id: BillingSubscriptionId,
194 name: String,
195 status: StripeSubscriptionStatus,
196 trial_end_at: Option<String>,
197 cancel_at: Option<String>,
198 /// Whether this subscription can be canceled.
199 is_cancelable: bool,
200}
201
202#[derive(Debug, Serialize)]
203struct ListBillingSubscriptionsResponse {
204 subscriptions: Vec<BillingSubscriptionJson>,
205}
206
207async fn list_billing_subscriptions(
208 Extension(app): Extension<Arc<AppState>>,
209 Query(params): Query<ListBillingSubscriptionsParams>,
210) -> Result<Json<ListBillingSubscriptionsResponse>> {
211 let user = app
212 .db
213 .get_user_by_github_user_id(params.github_user_id)
214 .await?
215 .ok_or_else(|| anyhow!("user not found"))?;
216
217 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
218
219 Ok(Json(ListBillingSubscriptionsResponse {
220 subscriptions: subscriptions
221 .into_iter()
222 .map(|subscription| BillingSubscriptionJson {
223 id: subscription.id,
224 name: match subscription.kind {
225 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
226 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
227 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
228 None => "Zed LLM Usage".to_string(),
229 },
230 status: subscription.stripe_subscription_status,
231 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
232 maybe!({
233 let end_at = subscription.stripe_current_period_end?;
234 let end_at = DateTime::from_timestamp(end_at, 0)?;
235
236 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
237 })
238 } else {
239 None
240 },
241 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
242 cancel_at
243 .and_utc()
244 .to_rfc3339_opts(SecondsFormat::Millis, true)
245 }),
246 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
247 && subscription.stripe_cancel_at.is_none(),
248 })
249 .collect(),
250 }))
251}
252
253#[derive(Debug, Clone, Copy, Deserialize)]
254#[serde(rename_all = "snake_case")]
255enum ProductCode {
256 ZedPro,
257 ZedProTrial,
258}
259
260#[derive(Debug, Deserialize)]
261struct CreateBillingSubscriptionBody {
262 github_user_id: i32,
263 product: Option<ProductCode>,
264}
265
266#[derive(Debug, Serialize)]
267struct CreateBillingSubscriptionResponse {
268 checkout_session_url: String,
269}
270
271/// Initiates a Stripe Checkout session for creating a billing subscription.
272async fn create_billing_subscription(
273 Extension(app): Extension<Arc<AppState>>,
274 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
275) -> Result<Json<CreateBillingSubscriptionResponse>> {
276 let user = app
277 .db
278 .get_user_by_github_user_id(body.github_user_id)
279 .await?
280 .ok_or_else(|| anyhow!("user not found"))?;
281
282 let Some(stripe_client) = app.stripe_client.clone() else {
283 log::error!("failed to retrieve Stripe client");
284 Err(Error::http(
285 StatusCode::NOT_IMPLEMENTED,
286 "not supported".into(),
287 ))?
288 };
289 let Some(stripe_billing) = app.stripe_billing.clone() else {
290 log::error!("failed to retrieve Stripe billing object");
291 Err(Error::http(
292 StatusCode::NOT_IMPLEMENTED,
293 "not supported".into(),
294 ))?
295 };
296 let Some(llm_db) = app.llm_db.clone() else {
297 log::error!("failed to retrieve LLM database");
298 Err(Error::http(
299 StatusCode::NOT_IMPLEMENTED,
300 "not supported".into(),
301 ))?
302 };
303
304 if app.db.has_active_billing_subscription(user.id).await? {
305 return Err(Error::http(
306 StatusCode::CONFLICT,
307 "user already has an active subscription".into(),
308 ));
309 }
310
311 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
312 if let Some(existing_billing_customer) = &existing_billing_customer {
313 if existing_billing_customer.has_overdue_invoices {
314 return Err(Error::http(
315 StatusCode::PAYMENT_REQUIRED,
316 "user has overdue invoices".into(),
317 ));
318 }
319 }
320
321 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
322 CustomerId::from_str(&existing_customer.stripe_customer_id)
323 .context("failed to parse customer ID")?
324 } else {
325 let customer = Customer::create(
326 &stripe_client,
327 CreateCustomer {
328 email: user.email_address.as_deref(),
329 ..Default::default()
330 },
331 )
332 .await?;
333
334 customer.id
335 };
336
337 let success_url = format!(
338 "{}/account?checkout_complete=1",
339 app.config.zed_dot_dev_url()
340 );
341
342 let checkout_session_url = match body.product {
343 Some(ProductCode::ZedPro) => {
344 stripe_billing
345 .checkout_with_price(
346 app.config.zed_pro_price_id()?,
347 customer_id,
348 &user.github_login,
349 &success_url,
350 )
351 .await?
352 }
353 Some(ProductCode::ZedProTrial) => {
354 if let Some(existing_billing_customer) = &existing_billing_customer {
355 if existing_billing_customer.trial_started_at.is_some() {
356 return Err(Error::http(
357 StatusCode::FORBIDDEN,
358 "user already used free trial".into(),
359 ));
360 }
361 }
362
363 stripe_billing
364 .checkout_with_zed_pro_trial(
365 app.config.zed_pro_price_id()?,
366 customer_id,
367 &user.github_login,
368 &success_url,
369 )
370 .await?
371 }
372 None => {
373 let default_model = llm_db.model(
374 zed_llm_client::LanguageModelProvider::Anthropic,
375 "claude-3-7-sonnet",
376 )?;
377 let stripe_model = stripe_billing.register_model(default_model).await?;
378 stripe_billing
379 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
380 .await?
381 }
382 };
383
384 Ok(Json(CreateBillingSubscriptionResponse {
385 checkout_session_url,
386 }))
387}
388
389#[derive(Debug, PartialEq, Deserialize)]
390#[serde(rename_all = "snake_case")]
391enum ManageSubscriptionIntent {
392 /// The user intends to manage their subscription.
393 ///
394 /// This will open the Stripe billing portal without putting the user in a specific flow.
395 ManageSubscription,
396 /// The user intends to upgrade to Zed Pro.
397 UpgradeToPro,
398 /// The user intends to cancel their subscription.
399 Cancel,
400 /// The user intends to stop the cancellation of their subscription.
401 StopCancellation,
402}
403
404#[derive(Debug, Deserialize)]
405struct ManageBillingSubscriptionBody {
406 github_user_id: i32,
407 intent: ManageSubscriptionIntent,
408 /// The ID of the subscription to manage.
409 subscription_id: BillingSubscriptionId,
410}
411
412#[derive(Debug, Serialize)]
413struct ManageBillingSubscriptionResponse {
414 billing_portal_session_url: Option<String>,
415}
416
417/// Initiates a Stripe customer portal session for managing a billing subscription.
418async fn manage_billing_subscription(
419 Extension(app): Extension<Arc<AppState>>,
420 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
421) -> Result<Json<ManageBillingSubscriptionResponse>> {
422 let user = app
423 .db
424 .get_user_by_github_user_id(body.github_user_id)
425 .await?
426 .ok_or_else(|| anyhow!("user not found"))?;
427
428 let Some(stripe_client) = app.stripe_client.clone() else {
429 log::error!("failed to retrieve Stripe client");
430 Err(Error::http(
431 StatusCode::NOT_IMPLEMENTED,
432 "not supported".into(),
433 ))?
434 };
435
436 let customer = app
437 .db
438 .get_billing_customer_by_user_id(user.id)
439 .await?
440 .ok_or_else(|| anyhow!("billing customer not found"))?;
441 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
442 .context("failed to parse customer ID")?;
443
444 let subscription = app
445 .db
446 .get_billing_subscription_by_id(body.subscription_id)
447 .await?
448 .ok_or_else(|| anyhow!("subscription not found"))?;
449 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
450 .context("failed to parse subscription ID")?;
451
452 if body.intent == ManageSubscriptionIntent::StopCancellation {
453 let updated_stripe_subscription = Subscription::update(
454 &stripe_client,
455 &subscription_id,
456 stripe::UpdateSubscription {
457 cancel_at_period_end: Some(false),
458 ..Default::default()
459 },
460 )
461 .await?;
462
463 app.db
464 .update_billing_subscription(
465 subscription.id,
466 &UpdateBillingSubscriptionParams {
467 stripe_cancel_at: ActiveValue::set(
468 updated_stripe_subscription
469 .cancel_at
470 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
471 .map(|time| time.naive_utc()),
472 ),
473 ..Default::default()
474 },
475 )
476 .await?;
477
478 return Ok(Json(ManageBillingSubscriptionResponse {
479 billing_portal_session_url: None,
480 }));
481 }
482
483 let flow = match body.intent {
484 ManageSubscriptionIntent::ManageSubscription => None,
485 ManageSubscriptionIntent::UpgradeToPro => {
486 let zed_pro_price_id = app.config.zed_pro_price_id()?;
487 let zed_free_price_id = app.config.zed_free_price_id()?;
488
489 let stripe_subscription =
490 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
491
492 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
493 && stripe_subscription.items.data.iter().any(|item| {
494 item.price
495 .as_ref()
496 .map_or(false, |price| price.id == zed_pro_price_id)
497 });
498 if is_on_zed_pro_trial {
499 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
500 Subscription::update(
501 &stripe_client,
502 &stripe_subscription.id,
503 stripe::UpdateSubscription {
504 trial_end: Some(stripe::Scheduled::now()),
505 ..Default::default()
506 },
507 )
508 .await?;
509
510 return Ok(Json(ManageBillingSubscriptionResponse {
511 billing_portal_session_url: None,
512 }));
513 }
514
515 let subscription_item_to_update = stripe_subscription
516 .items
517 .data
518 .iter()
519 .find_map(|item| {
520 let price = item.price.as_ref()?;
521
522 if price.id == zed_free_price_id {
523 Some(item.id.clone())
524 } else {
525 None
526 }
527 })
528 .ok_or_else(|| anyhow!("No subscription item to update"))?;
529
530 Some(CreateBillingPortalSessionFlowData {
531 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
532 subscription_update_confirm: Some(
533 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
534 subscription: subscription.stripe_subscription_id,
535 items: vec![
536 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
537 id: subscription_item_to_update.to_string(),
538 price: Some(zed_pro_price_id.to_string()),
539 quantity: Some(1),
540 },
541 ],
542 discounts: None,
543 },
544 ),
545 ..Default::default()
546 })
547 }
548 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
549 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
550 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
551 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
552 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
553 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
554 }),
555 ..Default::default()
556 }),
557 subscription_cancel: Some(
558 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
559 subscription: subscription.stripe_subscription_id,
560 retention: None,
561 },
562 ),
563 ..Default::default()
564 }),
565 ManageSubscriptionIntent::StopCancellation => unreachable!(),
566 };
567
568 let mut params = CreateBillingPortalSession::new(customer_id);
569 params.flow_data = flow;
570 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
571 params.return_url = Some(&return_url);
572
573 let session = BillingPortalSession::create(&stripe_client, params).await?;
574
575 Ok(Json(ManageBillingSubscriptionResponse {
576 billing_portal_session_url: Some(session.url),
577 }))
578}
579
/// How long to wait between consecutive polls of the Stripe events API.
///
/// The value trades off two concerns:
/// 1. It should be short enough that changes made in Stripe show up quickly.
/// 2. It should be long enough that polling doesn't eat into our rate limits.
///
/// For comparison, the Sequin folks report polling at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_millis(5_000);

/// How many events to request per page from the Stripe events API.
///
/// This is the maximum Stripe allows, which minimizes the number of requests:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// How many pages made up entirely of already-processed events we tolerate
/// before we stop fetching further pages.
///
/// This keeps us from over-fetching events we've already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
606
607/// Polls the Stripe events API periodically to reconcile the records in our
608/// database with the data in Stripe.
609pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
610 let Some(stripe_client) = app.stripe_client.clone() else {
611 log::warn!("failed to retrieve Stripe client");
612 return;
613 };
614
615 let executor = app.executor.clone();
616 executor.spawn_detached({
617 let executor = executor.clone();
618 async move {
619 loop {
620 poll_stripe_events(&app, &rpc_server, &stripe_client)
621 .await
622 .log_err();
623
624 executor.sleep(POLL_EVENTS_INTERVAL).await;
625 }
626 }
627 });
628}
629
630async fn poll_stripe_events(
631 app: &Arc<AppState>,
632 rpc_server: &Arc<Server>,
633 stripe_client: &stripe::Client,
634) -> anyhow::Result<()> {
635 fn event_type_to_string(event_type: EventType) -> String {
636 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
637 // so we need to unquote it.
638 event_type.to_string().trim_matches('"').to_string()
639 }
640
641 let event_types = [
642 EventType::CustomerCreated,
643 EventType::CustomerUpdated,
644 EventType::CustomerSubscriptionCreated,
645 EventType::CustomerSubscriptionUpdated,
646 EventType::CustomerSubscriptionPaused,
647 EventType::CustomerSubscriptionResumed,
648 EventType::CustomerSubscriptionDeleted,
649 ]
650 .into_iter()
651 .map(event_type_to_string)
652 .collect::<Vec<_>>();
653
654 let mut pages_of_already_processed_events = 0;
655 let mut unprocessed_events = Vec::new();
656
657 log::info!(
658 "Stripe events: starting retrieval for {}",
659 event_types.join(", ")
660 );
661 let mut params = ListEvents::new();
662 params.types = Some(event_types.clone());
663 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
664
665 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
666 .await?
667 .paginate(params);
668
669 loop {
670 let processed_event_ids = {
671 let event_ids = event_pages
672 .page
673 .data
674 .iter()
675 .map(|event| event.id.as_str())
676 .collect::<Vec<_>>();
677 app.db
678 .get_processed_stripe_events_by_event_ids(&event_ids)
679 .await?
680 .into_iter()
681 .map(|event| event.stripe_event_id)
682 .collect::<Vec<_>>()
683 };
684
685 let mut processed_events_in_page = 0;
686 let events_in_page = event_pages.page.data.len();
687 for event in &event_pages.page.data {
688 if processed_event_ids.contains(&event.id.to_string()) {
689 processed_events_in_page += 1;
690 log::debug!("Stripe events: already processed '{}', skipping", event.id);
691 } else {
692 unprocessed_events.push(event.clone());
693 }
694 }
695
696 if processed_events_in_page == events_in_page {
697 pages_of_already_processed_events += 1;
698 }
699
700 if event_pages.page.has_more {
701 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
702 {
703 log::info!(
704 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
705 );
706 break;
707 } else {
708 log::info!("Stripe events: retrieving next page");
709 event_pages = event_pages.next(&stripe_client).await?;
710 }
711 } else {
712 break;
713 }
714 }
715
716 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
717
718 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
719 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
720
721 for event in unprocessed_events {
722 let event_id = event.id.clone();
723 let processed_event_params = CreateProcessedStripeEventParams {
724 stripe_event_id: event.id.to_string(),
725 stripe_event_type: event_type_to_string(event.type_),
726 stripe_event_created_timestamp: event.created,
727 };
728
729 // If the event has happened too far in the past, we don't want to
730 // process it and risk overwriting other more-recent updates.
731 //
732 // 1 day was chosen arbitrarily. This could be made longer or shorter.
733 let one_day = Duration::from_secs(24 * 60 * 60);
734 let a_day_ago = Utc::now() - one_day;
735 if a_day_ago.timestamp() > event.created {
736 log::info!(
737 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
738 event_id
739 );
740 app.db
741 .create_processed_stripe_event(&processed_event_params)
742 .await?;
743
744 return Ok(());
745 }
746
747 let process_result = match event.type_ {
748 EventType::CustomerCreated | EventType::CustomerUpdated => {
749 handle_customer_event(app, stripe_client, event).await
750 }
751 EventType::CustomerSubscriptionCreated
752 | EventType::CustomerSubscriptionUpdated
753 | EventType::CustomerSubscriptionPaused
754 | EventType::CustomerSubscriptionResumed
755 | EventType::CustomerSubscriptionDeleted => {
756 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
757 }
758 _ => Ok(()),
759 };
760
761 if let Some(()) = process_result
762 .with_context(|| format!("failed to process event {event_id} successfully"))
763 .log_err()
764 {
765 app.db
766 .create_processed_stripe_event(&processed_event_params)
767 .await?;
768 }
769 }
770
771 Ok(())
772}
773
774async fn handle_customer_event(
775 app: &Arc<AppState>,
776 _stripe_client: &stripe::Client,
777 event: stripe::Event,
778) -> anyhow::Result<()> {
779 let EventObject::Customer(customer) = event.data.object else {
780 bail!("unexpected event payload for {}", event.id);
781 };
782
783 log::info!("handling Stripe {} event: {}", event.type_, event.id);
784
785 let Some(email) = customer.email else {
786 log::info!("Stripe customer has no email: skipping");
787 return Ok(());
788 };
789
790 let Some(user) = app.db.get_user_by_email(&email).await? else {
791 log::info!("no user found for email: skipping");
792 return Ok(());
793 };
794
795 if let Some(existing_customer) = app
796 .db
797 .get_billing_customer_by_stripe_customer_id(&customer.id)
798 .await?
799 {
800 app.db
801 .update_billing_customer(
802 existing_customer.id,
803 &UpdateBillingCustomerParams {
804 // For now we just leave the information as-is, as it is not
805 // likely to change.
806 ..Default::default()
807 },
808 )
809 .await?;
810 } else {
811 app.db
812 .create_billing_customer(&CreateBillingCustomerParams {
813 user_id: user.id,
814 stripe_customer_id: customer.id.to_string(),
815 })
816 .await?;
817 }
818
819 Ok(())
820}
821
/// Handles the Stripe customer-subscription events (created, updated, paused,
/// resumed, deleted) dispatched by `poll_stripe_events`, syncing the
/// subscription into our database.
///
/// In order, this:
/// 1. Derives the [`SubscriptionKind`] from the prices on the subscription's
///    items. It is `None` if either price ID isn't configured or no item uses
///    a known price.
/// 2. Looks up the billing customer via `find_or_create_billing_customer`, and
///    fails if none is returned.
/// 3. Records the trial start time on the billing customer while a Zed Pro
///    trial is in the `Trialing` status.
/// 4. Marks the billing customer as having overdue invoices if the
///    subscription was canceled with reason `PaymentFailed`.
/// 5. Updates our existing billing subscription record, after transferring
///    usage to the new billing period. If there is no record yet, creates one,
///    unless the user already has an active billing subscription.
/// 6. Refreshes the user's LLM tokens. This is skipped when step 5 returns
///    early because the user already has an active subscription.
///
/// # Errors
///
/// Returns an error if:
/// - the event payload is not a subscription;
/// - no billing customer is found;
/// - a period timestamp can't be converted;
/// - the LLM database isn't initialized (only needed for an existing
///   subscription);
/// - any database or helper call fails.
async fn handle_customer_subscription_event(
    app: &Arc<AppState>,
    rpc_server: &Arc<Server>,
    stripe_client: &stripe::Client,
    event: stripe::Event,
) -> anyhow::Result<()> {
    let EventObject::Subscription(subscription) = event.data.object else {
        bail!("unexpected event payload for {}", event.id);
    };

    log::info!("handling Stripe {} event: {}", event.type_, event.id);

    // Classify the subscription by the first item whose price is one of our
    // known Zed prices. `maybe!` yields `None` if either price ID is not configured.
    let subscription_kind = maybe!({
        let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
        let zed_free_price_id = app.config.zed_free_price_id().ok()?;

        subscription.items.data.iter().find_map(|item| {
            let price = item.price.as_ref()?;

            if price.id == zed_pro_price_id {
                // Trials and paid Zed Pro share a price; the subscription
                // status tells them apart.
                Some(if subscription.status == SubscriptionStatus::Trialing {
                    SubscriptionKind::ZedProTrial
                } else {
                    SubscriptionKind::ZedPro
                })
            } else if price.id == zed_free_price_id {
                Some(SubscriptionKind::ZedFree)
            } else {
                None
            }
        })
    });

    let billing_customer =
        find_or_create_billing_customer(app, stripe_client, subscription.customer)
            .await?
            .ok_or_else(|| anyhow!("billing customer not found"))?;

    // Remember that this customer has started a trial. Per the
    // `trial_started_at` check in `create_billing_subscription`, this blocks
    // further trials.
    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
        if subscription.status == SubscriptionStatus::Trialing {
            let current_period_start =
                DateTime::from_timestamp(subscription.current_period_start, 0)
                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;

            app.db
                .update_billing_customer(
                    billing_customer.id,
                    &UpdateBillingCustomerParams {
                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
                        ..Default::default()
                    },
                )
                .await?;
        }
    }

    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
        && subscription
            .cancellation_details
            .as_ref()
            .and_then(|details| details.reason)
            .map_or(false, |reason| {
                reason == CancellationDetailsReason::PaymentFailed
            });

    // Flag the customer so that `create_billing_subscription` rejects new
    // checkouts with `402` until this is cleared.
    if was_canceled_due_to_payment_failure {
        app.db
            .update_billing_customer(
                billing_customer.id,
                &UpdateBillingCustomerParams {
                    has_overdue_invoices: ActiveValue::set(true),
                    ..Default::default()
                },
            )
            .await?;
    }

    if let Some(existing_subscription) = app
        .db
        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
        .await?
    {
        let llm_db = app
            .llm_db
            .clone()
            .ok_or_else(|| anyhow!("LLM DB not initialized"))?;

        let new_period_start_at =
            chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
                .ok_or_else(|| anyhow!("No subscription period start"))?;
        let new_period_end_at =
            chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
                .ok_or_else(|| anyhow!("No subscription period end"))?;

        // Runs before the record below is overwritten, while
        // `existing_subscription` still reflects the previously stored state.
        llm_db
            .transfer_existing_subscription_usage(
                billing_customer.user_id,
                &existing_subscription,
                subscription_kind,
                new_period_start_at,
                new_period_end_at,
            )
            .await?;

        app.db
            .update_billing_subscription(
                existing_subscription.id,
                &UpdateBillingSubscriptionParams {
                    billing_customer_id: ActiveValue::set(billing_customer.id),
                    kind: ActiveValue::set(subscription_kind),
                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
                    stripe_cancel_at: ActiveValue::set(
                        subscription
                            .cancel_at
                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
                            .map(|time| time.naive_utc()),
                    ),
                    stripe_cancellation_reason: ActiveValue::set(
                        subscription
                            .cancellation_details
                            .and_then(|details| details.reason)
                            .map(|reason| reason.into()),
                    ),
                    stripe_current_period_start: ActiveValue::set(Some(
                        subscription.current_period_start,
                    )),
                    stripe_current_period_end: ActiveValue::set(Some(
                        subscription.current_period_end,
                    )),
                },
            )
            .await?;
    } else {
        // If the user already has an active billing subscription, ignore the
        // event and return an `Ok` to signal that it was processed
        // successfully.
        //
        // There is the possibility that this could cause us to not create a
        // subscription in the following scenario:
        //
        // 1. User has an active subscription A
        // 2. User cancels subscription A
        // 3. User creates a new subscription B
        // 4. We process the new subscription B before the cancellation of subscription A
        // 5. User ends up with no subscriptions
        //
        // In theory this situation shouldn't arise as we try to process the events in the order they occur.
        if app
            .db
            .has_active_billing_subscription(billing_customer.user_id)
            .await?
        {
            log::info!(
                "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
                user_id = billing_customer.user_id,
                subscription_id = subscription.id
            );
            return Ok(());
        }

        // NOTE(review): unlike the update branch above, no `stripe_cancel_at` is
        // passed here; confirm whether `CreateBillingSubscriptionParams` should carry it.
        app.db
            .create_billing_subscription(&CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: subscription_kind,
                stripe_subscription_id: subscription.id.to_string(),
                stripe_subscription_status: subscription.status.into(),
                stripe_cancellation_reason: subscription
                    .cancellation_details
                    .and_then(|details| details.reason)
                    .map(|reason| reason.into()),
                stripe_current_period_start: Some(subscription.current_period_start),
                stripe_current_period_end: Some(subscription.current_period_end),
            })
            .await?;
    }

    // When the user's subscription changes, we want to refresh their LLM tokens
    // to either grant/revoke access.
    rpc_server
        .refresh_llm_tokens_for_user(billing_customer.user_id)
        .await;

    Ok(())
}
1007
1008#[derive(Debug, Deserialize)]
1009struct GetMonthlySpendParams {
1010 github_user_id: i32,
1011}
1012
1013#[derive(Debug, Serialize)]
1014struct GetMonthlySpendResponse {
1015 monthly_free_tier_spend_in_cents: u32,
1016 monthly_free_tier_allowance_in_cents: u32,
1017 monthly_spend_in_cents: u32,
1018}
1019
1020async fn get_monthly_spend(
1021 Extension(app): Extension<Arc<AppState>>,
1022 Query(params): Query<GetMonthlySpendParams>,
1023) -> Result<Json<GetMonthlySpendResponse>> {
1024 let user = app
1025 .db
1026 .get_user_by_github_user_id(params.github_user_id)
1027 .await?
1028 .ok_or_else(|| anyhow!("user not found"))?;
1029
1030 let Some(llm_db) = app.llm_db.clone() else {
1031 return Err(Error::http(
1032 StatusCode::NOT_IMPLEMENTED,
1033 "LLM database not available".into(),
1034 ));
1035 };
1036
1037 let free_tier = user
1038 .custom_llm_monthly_allowance_in_cents
1039 .map(|allowance| Cents(allowance as u32))
1040 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1041
1042 let spending_for_month = llm_db
1043 .get_user_spending_for_month(user.id, Utc::now())
1044 .await?;
1045
1046 let free_tier_spend = Cents::min(spending_for_month, free_tier);
1047 let monthly_spend = spending_for_month.saturating_sub(free_tier);
1048
1049 Ok(Json(GetMonthlySpendResponse {
1050 monthly_free_tier_spend_in_cents: free_tier_spend.0,
1051 monthly_free_tier_allowance_in_cents: free_tier.0,
1052 monthly_spend_in_cents: monthly_spend.0,
1053 }))
1054}
1055
1056#[derive(Debug, Deserialize)]
1057struct GetCurrentUsageParams {
1058 github_user_id: i32,
1059}
1060
1061#[derive(Debug, Serialize)]
1062struct UsageCounts {
1063 pub used: i32,
1064 pub limit: Option<i32>,
1065 pub remaining: Option<i32>,
1066}
1067
1068#[derive(Debug, Serialize)]
1069struct GetCurrentUsageResponse {
1070 pub model_requests: UsageCounts,
1071 pub edit_predictions: UsageCounts,
1072}
1073
1074async fn get_current_usage(
1075 Extension(app): Extension<Arc<AppState>>,
1076 Query(params): Query<GetCurrentUsageParams>,
1077) -> Result<Json<GetCurrentUsageResponse>> {
1078 let user = app
1079 .db
1080 .get_user_by_github_user_id(params.github_user_id)
1081 .await?
1082 .ok_or_else(|| anyhow!("user not found"))?;
1083
1084 let Some(llm_db) = app.llm_db.clone() else {
1085 return Err(Error::http(
1086 StatusCode::NOT_IMPLEMENTED,
1087 "LLM database not available".into(),
1088 ));
1089 };
1090
1091 let empty_usage = GetCurrentUsageResponse {
1092 model_requests: UsageCounts {
1093 used: 0,
1094 limit: Some(0),
1095 remaining: Some(0),
1096 },
1097 edit_predictions: UsageCounts {
1098 used: 0,
1099 limit: Some(0),
1100 remaining: Some(0),
1101 },
1102 };
1103
1104 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1105 return Ok(Json(empty_usage));
1106 };
1107
1108 let subscription_period = maybe!({
1109 let period_start_at = subscription.current_period_start_at()?;
1110 let period_end_at = subscription.current_period_end_at()?;
1111
1112 Some((period_start_at, period_end_at))
1113 });
1114
1115 let Some((period_start_at, period_end_at)) = subscription_period else {
1116 return Ok(Json(empty_usage));
1117 };
1118
1119 let usage = llm_db
1120 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1121 .await?;
1122 let Some(usage) = usage else {
1123 return Ok(Json(empty_usage));
1124 };
1125
1126 let plan = match usage.plan {
1127 SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1128 SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1129 SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1130 };
1131
1132 let model_requests_limit = match plan.model_requests_limit() {
1133 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1134 zed_llm_client::UsageLimit::Unlimited => None,
1135 };
1136 let edit_prediction_limit = match plan.edit_predictions_limit() {
1137 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1138 zed_llm_client::UsageLimit::Unlimited => None,
1139 };
1140
1141 Ok(Json(GetCurrentUsageResponse {
1142 model_requests: UsageCounts {
1143 used: usage.model_requests,
1144 limit: model_requests_limit,
1145 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1146 },
1147 edit_predictions: UsageCounts {
1148 used: usage.edit_predictions,
1149 limit: edit_prediction_limit,
1150 remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1151 },
1152 }))
1153}
1154
1155impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1156 fn from(value: SubscriptionStatus) -> Self {
1157 match value {
1158 SubscriptionStatus::Incomplete => Self::Incomplete,
1159 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1160 SubscriptionStatus::Trialing => Self::Trialing,
1161 SubscriptionStatus::Active => Self::Active,
1162 SubscriptionStatus::PastDue => Self::PastDue,
1163 SubscriptionStatus::Canceled => Self::Canceled,
1164 SubscriptionStatus::Unpaid => Self::Unpaid,
1165 SubscriptionStatus::Paused => Self::Paused,
1166 }
1167 }
1168}
1169
1170impl From<CancellationDetailsReason> for StripeCancellationReason {
1171 fn from(value: CancellationDetailsReason) -> Self {
1172 match value {
1173 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1174 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1175 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1176 }
1177 }
1178}
1179
1180/// Finds or creates a billing customer using the provided customer.
1181async fn find_or_create_billing_customer(
1182 app: &Arc<AppState>,
1183 stripe_client: &stripe::Client,
1184 customer_or_id: Expandable<Customer>,
1185) -> anyhow::Result<Option<billing_customer::Model>> {
1186 let customer_id = match &customer_or_id {
1187 Expandable::Id(id) => id,
1188 Expandable::Object(customer) => customer.id.as_ref(),
1189 };
1190
1191 // If we already have a billing customer record associated with the Stripe customer,
1192 // there's nothing more we need to do.
1193 if let Some(billing_customer) = app
1194 .db
1195 .get_billing_customer_by_stripe_customer_id(customer_id)
1196 .await?
1197 {
1198 return Ok(Some(billing_customer));
1199 }
1200
1201 // If all we have is a customer ID, resolve it to a full customer record by
1202 // hitting the Stripe API.
1203 let customer = match customer_or_id {
1204 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1205 Expandable::Object(customer) => *customer,
1206 };
1207
1208 let Some(email) = customer.email else {
1209 return Ok(None);
1210 };
1211
1212 let Some(user) = app.db.get_user_by_email(&email).await? else {
1213 return Ok(None);
1214 };
1215
1216 let billing_customer = app
1217 .db
1218 .create_billing_customer(&CreateBillingCustomerParams {
1219 user_id: user.id,
1220 stripe_customer_id: customer.id.to_string(),
1221 })
1222 .await?;
1223
1224 Ok(Some(billing_customer))
1225}
1226
1227const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1228
1229pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
1230 let Some(stripe_billing) = app.stripe_billing.clone() else {
1231 log::warn!("failed to retrieve Stripe billing object");
1232 return;
1233 };
1234 let Some(llm_db) = app.llm_db.clone() else {
1235 log::warn!("failed to retrieve LLM database");
1236 return;
1237 };
1238
1239 let executor = app.executor.clone();
1240 executor.spawn_detached({
1241 let executor = executor.clone();
1242 async move {
1243 loop {
1244 sync_with_stripe(&app, &llm_db, &stripe_billing)
1245 .await
1246 .context("failed to sync LLM usage to Stripe")
1247 .trace_err();
1248 executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
1249 }
1250 }
1251 });
1252}
1253
1254async fn sync_with_stripe(
1255 app: &Arc<AppState>,
1256 llm_db: &Arc<LlmDatabase>,
1257 stripe_billing: &Arc<StripeBilling>,
1258) -> anyhow::Result<()> {
1259 let events = llm_db.get_billing_events().await?;
1260 let user_ids = events
1261 .iter()
1262 .map(|(event, _)| event.user_id)
1263 .collect::<HashSet<UserId>>();
1264 let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1265
1266 for (event, model) in events {
1267 let Some((stripe_db_customer, stripe_db_subscription)) =
1268 stripe_subscriptions.get(&event.user_id)
1269 else {
1270 tracing::warn!(
1271 user_id = event.user_id.0,
1272 "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1273 );
1274 continue;
1275 };
1276 let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1277 .stripe_subscription_id
1278 .parse()
1279 .context("failed to parse stripe subscription id from db")?;
1280 let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1281 .stripe_customer_id
1282 .parse()
1283 .context("failed to parse stripe customer id from db")?;
1284
1285 let stripe_model = stripe_billing.register_model(&model).await?;
1286 stripe_billing
1287 .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1288 .await?;
1289 stripe_billing
1290 .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
1291 .await?;
1292 llm_db.consume_billing_event(event.id).await?;
1293 }
1294
1295 Ok(())
1296}