1use anyhow::{Context, anyhow, bail};
2use axum::{
3 Extension, Json, Router,
4 extract::{self, Query},
5 routing::{get, post},
6};
7use chrono::{DateTime, SecondsFormat, Utc};
8use collections::HashSet;
9use reqwest::StatusCode;
10use sea_orm::ActiveValue;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::{str::FromStr, sync::Arc, time::Duration};
14use stripe::{
15 BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
19 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
20 CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
21 EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
22};
23use util::{ResultExt, maybe};
24
25use crate::api::events::SnowflakeRow;
26use crate::db::billing_subscription::{
27 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
28};
29use crate::llm::db::subscription_usage_meter::CompletionMode;
30use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
31use crate::rpc::{ResultExt as _, Server};
32use crate::{AppState, Cents, Error, Result};
33use crate::{db::UserId, llm::db::LlmDatabase};
34use crate::{
35 db::{
36 BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
37 CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
38 UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
39 },
40 stripe_billing::StripeBilling,
41};
42
43pub fn router() -> Router {
44 Router::new()
45 .route(
46 "/billing/preferences",
47 get(get_billing_preferences).put(update_billing_preferences),
48 )
49 .route(
50 "/billing/subscriptions",
51 get(list_billing_subscriptions).post(create_billing_subscription),
52 )
53 .route(
54 "/billing/subscriptions/manage",
55 post(manage_billing_subscription),
56 )
57 .route("/billing/monthly_spend", get(get_monthly_spend))
58 .route("/billing/usage", get(get_current_usage))
59}
60
61#[derive(Debug, Deserialize)]
62struct GetBillingPreferencesParams {
63 github_user_id: i32,
64}
65
66#[derive(Debug, Serialize)]
67struct BillingPreferencesResponse {
68 max_monthly_llm_usage_spending_in_cents: i32,
69 model_request_overages_enabled: bool,
70 model_request_overages_spend_limit_in_cents: i32,
71}
72
73async fn get_billing_preferences(
74 Extension(app): Extension<Arc<AppState>>,
75 Query(params): Query<GetBillingPreferencesParams>,
76) -> Result<Json<BillingPreferencesResponse>> {
77 let user = app
78 .db
79 .get_user_by_github_user_id(params.github_user_id)
80 .await?
81 .ok_or_else(|| anyhow!("user not found"))?;
82
83 let preferences = app.db.get_billing_preferences(user.id).await?;
84
85 Ok(Json(BillingPreferencesResponse {
86 max_monthly_llm_usage_spending_in_cents: preferences
87 .as_ref()
88 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
89 preferences.max_monthly_llm_usage_spending_in_cents
90 }),
91 model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
92 preferences.model_request_overages_enabled
93 }),
94 model_request_overages_spend_limit_in_cents: preferences
95 .as_ref()
96 .map_or(0, |preferences| {
97 preferences.model_request_overages_spend_limit_in_cents
98 }),
99 }))
100}
101
102#[derive(Debug, Deserialize)]
103struct UpdateBillingPreferencesBody {
104 github_user_id: i32,
105 #[serde(default)]
106 max_monthly_llm_usage_spending_in_cents: i32,
107 #[serde(default)]
108 model_request_overages_enabled: bool,
109 #[serde(default)]
110 model_request_overages_spend_limit_in_cents: i32,
111}
112
113async fn update_billing_preferences(
114 Extension(app): Extension<Arc<AppState>>,
115 Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
116 extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
117) -> Result<Json<BillingPreferencesResponse>> {
118 let user = app
119 .db
120 .get_user_by_github_user_id(body.github_user_id)
121 .await?
122 .ok_or_else(|| anyhow!("user not found"))?;
123
124 let max_monthly_llm_usage_spending_in_cents =
125 body.max_monthly_llm_usage_spending_in_cents.max(0);
126 let model_request_overages_spend_limit_in_cents =
127 body.model_request_overages_spend_limit_in_cents.max(0);
128
129 let billing_preferences =
130 if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
131 app.db
132 .update_billing_preferences(
133 user.id,
134 &UpdateBillingPreferencesParams {
135 max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
136 max_monthly_llm_usage_spending_in_cents,
137 ),
138 model_request_overages_enabled: ActiveValue::set(
139 body.model_request_overages_enabled,
140 ),
141 model_request_overages_spend_limit_in_cents: ActiveValue::set(
142 model_request_overages_spend_limit_in_cents,
143 ),
144 },
145 )
146 .await?
147 } else {
148 app.db
149 .create_billing_preferences(
150 user.id,
151 &crate::db::CreateBillingPreferencesParams {
152 max_monthly_llm_usage_spending_in_cents,
153 model_request_overages_enabled: body.model_request_overages_enabled,
154 model_request_overages_spend_limit_in_cents,
155 },
156 )
157 .await?
158 };
159
160 SnowflakeRow::new(
161 "Billing Preferences Updated",
162 Some(user.metrics_id),
163 user.admin,
164 None,
165 json!({
166 "user_id": user.id,
167 "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
168 "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
169 "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
170 }),
171 )
172 .write(&app.kinesis_client, &app.config.kinesis_stream)
173 .await
174 .log_err();
175
176 rpc_server.refresh_llm_tokens_for_user(user.id).await;
177
178 Ok(Json(BillingPreferencesResponse {
179 max_monthly_llm_usage_spending_in_cents: billing_preferences
180 .max_monthly_llm_usage_spending_in_cents,
181 model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
182 model_request_overages_spend_limit_in_cents: billing_preferences
183 .model_request_overages_spend_limit_in_cents,
184 }))
185}
186
187#[derive(Debug, Deserialize)]
188struct ListBillingSubscriptionsParams {
189 github_user_id: i32,
190}
191
192#[derive(Debug, Serialize)]
193struct BillingSubscriptionJson {
194 id: BillingSubscriptionId,
195 name: String,
196 status: StripeSubscriptionStatus,
197 trial_end_at: Option<String>,
198 cancel_at: Option<String>,
199 /// Whether this subscription can be canceled.
200 is_cancelable: bool,
201}
202
203#[derive(Debug, Serialize)]
204struct ListBillingSubscriptionsResponse {
205 subscriptions: Vec<BillingSubscriptionJson>,
206}
207
208async fn list_billing_subscriptions(
209 Extension(app): Extension<Arc<AppState>>,
210 Query(params): Query<ListBillingSubscriptionsParams>,
211) -> Result<Json<ListBillingSubscriptionsResponse>> {
212 let user = app
213 .db
214 .get_user_by_github_user_id(params.github_user_id)
215 .await?
216 .ok_or_else(|| anyhow!("user not found"))?;
217
218 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
219
220 Ok(Json(ListBillingSubscriptionsResponse {
221 subscriptions: subscriptions
222 .into_iter()
223 .map(|subscription| BillingSubscriptionJson {
224 id: subscription.id,
225 name: match subscription.kind {
226 Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
227 Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
228 Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
229 None => "Zed LLM Usage".to_string(),
230 },
231 status: subscription.stripe_subscription_status,
232 trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
233 maybe!({
234 let end_at = subscription.stripe_current_period_end?;
235 let end_at = DateTime::from_timestamp(end_at, 0)?;
236
237 Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
238 })
239 } else {
240 None
241 },
242 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
243 cancel_at
244 .and_utc()
245 .to_rfc3339_opts(SecondsFormat::Millis, true)
246 }),
247 is_cancelable: subscription.stripe_subscription_status.is_cancelable()
248 && subscription.stripe_cancel_at.is_none(),
249 })
250 .collect(),
251 }))
252}
253
254#[derive(Debug, Clone, Copy, Deserialize)]
255#[serde(rename_all = "snake_case")]
256enum ProductCode {
257 ZedPro,
258 ZedProTrial,
259}
260
261#[derive(Debug, Deserialize)]
262struct CreateBillingSubscriptionBody {
263 github_user_id: i32,
264 product: Option<ProductCode>,
265}
266
267#[derive(Debug, Serialize)]
268struct CreateBillingSubscriptionResponse {
269 checkout_session_url: String,
270}
271
272/// Initiates a Stripe Checkout session for creating a billing subscription.
273async fn create_billing_subscription(
274 Extension(app): Extension<Arc<AppState>>,
275 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
276) -> Result<Json<CreateBillingSubscriptionResponse>> {
277 let user = app
278 .db
279 .get_user_by_github_user_id(body.github_user_id)
280 .await?
281 .ok_or_else(|| anyhow!("user not found"))?;
282
283 let Some(stripe_client) = app.stripe_client.clone() else {
284 log::error!("failed to retrieve Stripe client");
285 Err(Error::http(
286 StatusCode::NOT_IMPLEMENTED,
287 "not supported".into(),
288 ))?
289 };
290 let Some(stripe_billing) = app.stripe_billing.clone() else {
291 log::error!("failed to retrieve Stripe billing object");
292 Err(Error::http(
293 StatusCode::NOT_IMPLEMENTED,
294 "not supported".into(),
295 ))?
296 };
297 let Some(llm_db) = app.llm_db.clone() else {
298 log::error!("failed to retrieve LLM database");
299 Err(Error::http(
300 StatusCode::NOT_IMPLEMENTED,
301 "not supported".into(),
302 ))?
303 };
304
305 if app.db.has_active_billing_subscription(user.id).await? {
306 return Err(Error::http(
307 StatusCode::CONFLICT,
308 "user already has an active subscription".into(),
309 ));
310 }
311
312 let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
313 if let Some(existing_billing_customer) = &existing_billing_customer {
314 if existing_billing_customer.has_overdue_invoices {
315 return Err(Error::http(
316 StatusCode::PAYMENT_REQUIRED,
317 "user has overdue invoices".into(),
318 ));
319 }
320 }
321
322 let customer_id = if let Some(existing_customer) = &existing_billing_customer {
323 CustomerId::from_str(&existing_customer.stripe_customer_id)
324 .context("failed to parse customer ID")?
325 } else {
326 let existing_customer = if let Some(email) = user.email_address.as_deref() {
327 let customers = Customer::list(
328 &stripe_client,
329 &stripe::ListCustomers {
330 email: Some(email),
331 ..Default::default()
332 },
333 )
334 .await?;
335
336 customers.data.first().cloned()
337 } else {
338 None
339 };
340
341 if let Some(existing_customer) = existing_customer {
342 existing_customer.id
343 } else {
344 let customer = Customer::create(
345 &stripe_client,
346 CreateCustomer {
347 email: user.email_address.as_deref(),
348 ..Default::default()
349 },
350 )
351 .await?;
352
353 customer.id
354 }
355 };
356
357 let success_url = format!(
358 "{}/account?checkout_complete=1",
359 app.config.zed_dot_dev_url()
360 );
361
362 let checkout_session_url = match body.product {
363 Some(ProductCode::ZedPro) => {
364 stripe_billing
365 .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
366 .await?
367 }
368 Some(ProductCode::ZedProTrial) => {
369 if let Some(existing_billing_customer) = &existing_billing_customer {
370 if existing_billing_customer.trial_started_at.is_some() {
371 return Err(Error::http(
372 StatusCode::FORBIDDEN,
373 "user already used free trial".into(),
374 ));
375 }
376 }
377
378 let feature_flags = app.db.get_user_flags(user.id).await?;
379
380 stripe_billing
381 .checkout_with_zed_pro_trial(
382 customer_id,
383 &user.github_login,
384 feature_flags,
385 &success_url,
386 )
387 .await?
388 }
389 None => {
390 let default_model = llm_db.model(
391 zed_llm_client::LanguageModelProvider::Anthropic,
392 "claude-3-7-sonnet",
393 )?;
394 let stripe_model = stripe_billing
395 .register_model_for_token_based_usage(default_model)
396 .await?;
397 stripe_billing
398 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
399 .await?
400 }
401 };
402
403 Ok(Json(CreateBillingSubscriptionResponse {
404 checkout_session_url,
405 }))
406}
407
408#[derive(Debug, PartialEq, Deserialize)]
409#[serde(rename_all = "snake_case")]
410enum ManageSubscriptionIntent {
411 /// The user intends to manage their subscription.
412 ///
413 /// This will open the Stripe billing portal without putting the user in a specific flow.
414 ManageSubscription,
415 /// The user intends to upgrade to Zed Pro.
416 UpgradeToPro,
417 /// The user intends to cancel their subscription.
418 Cancel,
419 /// The user intends to stop the cancellation of their subscription.
420 StopCancellation,
421}
422
423#[derive(Debug, Deserialize)]
424struct ManageBillingSubscriptionBody {
425 github_user_id: i32,
426 intent: ManageSubscriptionIntent,
427 /// The ID of the subscription to manage.
428 subscription_id: BillingSubscriptionId,
429}
430
431#[derive(Debug, Serialize)]
432struct ManageBillingSubscriptionResponse {
433 billing_portal_session_url: Option<String>,
434}
435
436/// Initiates a Stripe customer portal session for managing a billing subscription.
437async fn manage_billing_subscription(
438 Extension(app): Extension<Arc<AppState>>,
439 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
440) -> Result<Json<ManageBillingSubscriptionResponse>> {
441 let user = app
442 .db
443 .get_user_by_github_user_id(body.github_user_id)
444 .await?
445 .ok_or_else(|| anyhow!("user not found"))?;
446
447 let Some(stripe_client) = app.stripe_client.clone() else {
448 log::error!("failed to retrieve Stripe client");
449 Err(Error::http(
450 StatusCode::NOT_IMPLEMENTED,
451 "not supported".into(),
452 ))?
453 };
454
455 let Some(stripe_billing) = app.stripe_billing.clone() else {
456 log::error!("failed to retrieve Stripe billing object");
457 Err(Error::http(
458 StatusCode::NOT_IMPLEMENTED,
459 "not supported".into(),
460 ))?
461 };
462
463 let customer = app
464 .db
465 .get_billing_customer_by_user_id(user.id)
466 .await?
467 .ok_or_else(|| anyhow!("billing customer not found"))?;
468 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
469 .context("failed to parse customer ID")?;
470
471 let subscription = app
472 .db
473 .get_billing_subscription_by_id(body.subscription_id)
474 .await?
475 .ok_or_else(|| anyhow!("subscription not found"))?;
476 let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
477 .context("failed to parse subscription ID")?;
478
479 if body.intent == ManageSubscriptionIntent::StopCancellation {
480 let updated_stripe_subscription = Subscription::update(
481 &stripe_client,
482 &subscription_id,
483 stripe::UpdateSubscription {
484 cancel_at_period_end: Some(false),
485 ..Default::default()
486 },
487 )
488 .await?;
489
490 app.db
491 .update_billing_subscription(
492 subscription.id,
493 &UpdateBillingSubscriptionParams {
494 stripe_cancel_at: ActiveValue::set(
495 updated_stripe_subscription
496 .cancel_at
497 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
498 .map(|time| time.naive_utc()),
499 ),
500 ..Default::default()
501 },
502 )
503 .await?;
504
505 return Ok(Json(ManageBillingSubscriptionResponse {
506 billing_portal_session_url: None,
507 }));
508 }
509
510 let flow = match body.intent {
511 ManageSubscriptionIntent::ManageSubscription => None,
512 ManageSubscriptionIntent::UpgradeToPro => {
513 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
514 let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
515
516 let stripe_subscription =
517 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
518
519 let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
520 && stripe_subscription.items.data.iter().any(|item| {
521 item.price
522 .as_ref()
523 .map_or(false, |price| price.id == zed_pro_price_id)
524 });
525 if is_on_zed_pro_trial {
526 // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
527 Subscription::update(
528 &stripe_client,
529 &stripe_subscription.id,
530 stripe::UpdateSubscription {
531 trial_end: Some(stripe::Scheduled::now()),
532 ..Default::default()
533 },
534 )
535 .await?;
536
537 return Ok(Json(ManageBillingSubscriptionResponse {
538 billing_portal_session_url: None,
539 }));
540 }
541
542 let subscription_item_to_update = stripe_subscription
543 .items
544 .data
545 .iter()
546 .find_map(|item| {
547 let price = item.price.as_ref()?;
548
549 if price.id == zed_free_price_id {
550 Some(item.id.clone())
551 } else {
552 None
553 }
554 })
555 .ok_or_else(|| anyhow!("No subscription item to update"))?;
556
557 Some(CreateBillingPortalSessionFlowData {
558 type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
559 subscription_update_confirm: Some(
560 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
561 subscription: subscription.stripe_subscription_id,
562 items: vec![
563 CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
564 id: subscription_item_to_update.to_string(),
565 price: Some(zed_pro_price_id.to_string()),
566 quantity: Some(1),
567 },
568 ],
569 discounts: None,
570 },
571 ),
572 ..Default::default()
573 })
574 }
575 ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
576 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
577 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
578 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
579 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
580 return_url: format!("{}/account", app.config.zed_dot_dev_url()),
581 }),
582 ..Default::default()
583 }),
584 subscription_cancel: Some(
585 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
586 subscription: subscription.stripe_subscription_id,
587 retention: None,
588 },
589 ),
590 ..Default::default()
591 }),
592 ManageSubscriptionIntent::StopCancellation => unreachable!(),
593 };
594
595 let mut params = CreateBillingPortalSession::new(customer_id);
596 params.flow_data = flow;
597 let return_url = format!("{}/account", app.config.zed_dot_dev_url());
598 params.return_url = Some(&return_url);
599
600 let session = BillingPortalSession::create(&stripe_client, params).await?;
601
602 Ok(Json(ManageBillingSubscriptionResponse {
603 billing_portal_session_url: Some(session.url),
604 }))
605}
606
/// The amount of time we wait between consecutive polls of Stripe events
/// (currently 5 seconds).
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes.
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
619
/// The maximum number of events to request per page from the Stripe events API.
///
/// We set this to 100 (the maximum) so that we make fewer requests to Stripe.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
626
/// The number of pages consisting entirely of already-processed events that
/// we will see before we stop retrieving further pages.
///
/// This prevents over-fetching from the Stripe events API for events we have
/// already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
633
634/// Polls the Stripe events API periodically to reconcile the records in our
635/// database with the data in Stripe.
636pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
637 let Some(stripe_client) = app.stripe_client.clone() else {
638 log::warn!("failed to retrieve Stripe client");
639 return;
640 };
641
642 let executor = app.executor.clone();
643 executor.spawn_detached({
644 let executor = executor.clone();
645 async move {
646 loop {
647 poll_stripe_events(&app, &rpc_server, &stripe_client)
648 .await
649 .log_err();
650
651 executor.sleep(POLL_EVENTS_INTERVAL).await;
652 }
653 }
654 });
655}
656
657async fn poll_stripe_events(
658 app: &Arc<AppState>,
659 rpc_server: &Arc<Server>,
660 stripe_client: &stripe::Client,
661) -> anyhow::Result<()> {
662 fn event_type_to_string(event_type: EventType) -> String {
663 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
664 // so we need to unquote it.
665 event_type.to_string().trim_matches('"').to_string()
666 }
667
668 let event_types = [
669 EventType::CustomerCreated,
670 EventType::CustomerUpdated,
671 EventType::CustomerSubscriptionCreated,
672 EventType::CustomerSubscriptionUpdated,
673 EventType::CustomerSubscriptionPaused,
674 EventType::CustomerSubscriptionResumed,
675 EventType::CustomerSubscriptionDeleted,
676 ]
677 .into_iter()
678 .map(event_type_to_string)
679 .collect::<Vec<_>>();
680
681 let mut pages_of_already_processed_events = 0;
682 let mut unprocessed_events = Vec::new();
683
684 log::info!(
685 "Stripe events: starting retrieval for {}",
686 event_types.join(", ")
687 );
688 let mut params = ListEvents::new();
689 params.types = Some(event_types.clone());
690 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
691
692 let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
693 .await?
694 .paginate(params);
695
696 loop {
697 let processed_event_ids = {
698 let event_ids = event_pages
699 .page
700 .data
701 .iter()
702 .map(|event| event.id.as_str())
703 .collect::<Vec<_>>();
704 app.db
705 .get_processed_stripe_events_by_event_ids(&event_ids)
706 .await?
707 .into_iter()
708 .map(|event| event.stripe_event_id)
709 .collect::<Vec<_>>()
710 };
711
712 let mut processed_events_in_page = 0;
713 let events_in_page = event_pages.page.data.len();
714 for event in &event_pages.page.data {
715 if processed_event_ids.contains(&event.id.to_string()) {
716 processed_events_in_page += 1;
717 log::debug!("Stripe events: already processed '{}', skipping", event.id);
718 } else {
719 unprocessed_events.push(event.clone());
720 }
721 }
722
723 if processed_events_in_page == events_in_page {
724 pages_of_already_processed_events += 1;
725 }
726
727 if event_pages.page.has_more {
728 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
729 {
730 log::info!(
731 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
732 );
733 break;
734 } else {
735 log::info!("Stripe events: retrieving next page");
736 event_pages = event_pages.next(&stripe_client).await?;
737 }
738 } else {
739 break;
740 }
741 }
742
743 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
744
745 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
746 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
747
748 for event in unprocessed_events {
749 let event_id = event.id.clone();
750 let processed_event_params = CreateProcessedStripeEventParams {
751 stripe_event_id: event.id.to_string(),
752 stripe_event_type: event_type_to_string(event.type_),
753 stripe_event_created_timestamp: event.created,
754 };
755
756 // If the event has happened too far in the past, we don't want to
757 // process it and risk overwriting other more-recent updates.
758 //
759 // 1 day was chosen arbitrarily. This could be made longer or shorter.
760 let one_day = Duration::from_secs(24 * 60 * 60);
761 let a_day_ago = Utc::now() - one_day;
762 if a_day_ago.timestamp() > event.created {
763 log::info!(
764 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
765 event_id
766 );
767 app.db
768 .create_processed_stripe_event(&processed_event_params)
769 .await?;
770
771 return Ok(());
772 }
773
774 let process_result = match event.type_ {
775 EventType::CustomerCreated | EventType::CustomerUpdated => {
776 handle_customer_event(app, stripe_client, event).await
777 }
778 EventType::CustomerSubscriptionCreated
779 | EventType::CustomerSubscriptionUpdated
780 | EventType::CustomerSubscriptionPaused
781 | EventType::CustomerSubscriptionResumed
782 | EventType::CustomerSubscriptionDeleted => {
783 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
784 }
785 _ => Ok(()),
786 };
787
788 if let Some(()) = process_result
789 .with_context(|| format!("failed to process event {event_id} successfully"))
790 .log_err()
791 {
792 app.db
793 .create_processed_stripe_event(&processed_event_params)
794 .await?;
795 }
796 }
797
798 Ok(())
799}
800
801async fn handle_customer_event(
802 app: &Arc<AppState>,
803 _stripe_client: &stripe::Client,
804 event: stripe::Event,
805) -> anyhow::Result<()> {
806 let EventObject::Customer(customer) = event.data.object else {
807 bail!("unexpected event payload for {}", event.id);
808 };
809
810 log::info!("handling Stripe {} event: {}", event.type_, event.id);
811
812 let Some(email) = customer.email else {
813 log::info!("Stripe customer has no email: skipping");
814 return Ok(());
815 };
816
817 let Some(user) = app.db.get_user_by_email(&email).await? else {
818 log::info!("no user found for email: skipping");
819 return Ok(());
820 };
821
822 if let Some(existing_customer) = app
823 .db
824 .get_billing_customer_by_stripe_customer_id(&customer.id)
825 .await?
826 {
827 app.db
828 .update_billing_customer(
829 existing_customer.id,
830 &UpdateBillingCustomerParams {
831 // For now we just leave the information as-is, as it is not
832 // likely to change.
833 ..Default::default()
834 },
835 )
836 .await?;
837 } else {
838 app.db
839 .create_billing_customer(&CreateBillingCustomerParams {
840 user_id: user.id,
841 stripe_customer_id: customer.id.to_string(),
842 })
843 .await?;
844 }
845
846 Ok(())
847}
848
849async fn handle_customer_subscription_event(
850 app: &Arc<AppState>,
851 rpc_server: &Arc<Server>,
852 stripe_client: &stripe::Client,
853 event: stripe::Event,
854) -> anyhow::Result<()> {
855 let EventObject::Subscription(subscription) = event.data.object else {
856 bail!("unexpected event payload for {}", event.id);
857 };
858
859 log::info!("handling Stripe {} event: {}", event.type_, event.id);
860
861 let subscription_kind = maybe!(async {
862 let stripe_billing = app.stripe_billing.clone()?;
863
864 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
865 let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
866
867 subscription.items.data.iter().find_map(|item| {
868 let price = item.price.as_ref()?;
869
870 if price.id == zed_pro_price_id {
871 Some(if subscription.status == SubscriptionStatus::Trialing {
872 SubscriptionKind::ZedProTrial
873 } else {
874 SubscriptionKind::ZedPro
875 })
876 } else if price.id == zed_free_price_id {
877 Some(SubscriptionKind::ZedFree)
878 } else {
879 None
880 }
881 })
882 })
883 .await;
884
885 let billing_customer =
886 find_or_create_billing_customer(app, stripe_client, subscription.customer)
887 .await?
888 .ok_or_else(|| anyhow!("billing customer not found"))?;
889
890 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
891 if subscription.status == SubscriptionStatus::Trialing {
892 let current_period_start =
893 DateTime::from_timestamp(subscription.current_period_start, 0)
894 .ok_or_else(|| anyhow!("No trial subscription period start"))?;
895
896 app.db
897 .update_billing_customer(
898 billing_customer.id,
899 &UpdateBillingCustomerParams {
900 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
901 ..Default::default()
902 },
903 )
904 .await?;
905 }
906 }
907
908 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
909 && subscription
910 .cancellation_details
911 .as_ref()
912 .and_then(|details| details.reason)
913 .map_or(false, |reason| {
914 reason == CancellationDetailsReason::PaymentFailed
915 });
916
917 if was_canceled_due_to_payment_failure {
918 app.db
919 .update_billing_customer(
920 billing_customer.id,
921 &UpdateBillingCustomerParams {
922 has_overdue_invoices: ActiveValue::set(true),
923 ..Default::default()
924 },
925 )
926 .await?;
927 }
928
929 if let Some(existing_subscription) = app
930 .db
931 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
932 .await?
933 {
934 let llm_db = app
935 .llm_db
936 .clone()
937 .ok_or_else(|| anyhow!("LLM DB not initialized"))?;
938
939 let new_period_start_at =
940 chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
941 .ok_or_else(|| anyhow!("No subscription period start"))?;
942 let new_period_end_at =
943 chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
944 .ok_or_else(|| anyhow!("No subscription period end"))?;
945
946 llm_db
947 .transfer_existing_subscription_usage(
948 billing_customer.user_id,
949 &existing_subscription,
950 subscription_kind,
951 new_period_start_at,
952 new_period_end_at,
953 )
954 .await?;
955
956 app.db
957 .update_billing_subscription(
958 existing_subscription.id,
959 &UpdateBillingSubscriptionParams {
960 billing_customer_id: ActiveValue::set(billing_customer.id),
961 kind: ActiveValue::set(subscription_kind),
962 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
963 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
964 stripe_cancel_at: ActiveValue::set(
965 subscription
966 .cancel_at
967 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
968 .map(|time| time.naive_utc()),
969 ),
970 stripe_cancellation_reason: ActiveValue::set(
971 subscription
972 .cancellation_details
973 .and_then(|details| details.reason)
974 .map(|reason| reason.into()),
975 ),
976 stripe_current_period_start: ActiveValue::set(Some(
977 subscription.current_period_start,
978 )),
979 stripe_current_period_end: ActiveValue::set(Some(
980 subscription.current_period_end,
981 )),
982 },
983 )
984 .await?;
985 } else {
986 // If the user already has an active billing subscription, ignore the
987 // event and return an `Ok` to signal that it was processed
988 // successfully.
989 //
990 // There is the possibility that this could cause us to not create a
991 // subscription in the following scenario:
992 //
993 // 1. User has an active subscription A
994 // 2. User cancels subscription A
995 // 3. User creates a new subscription B
996 // 4. We process the new subscription B before the cancellation of subscription A
997 // 5. User ends up with no subscriptions
998 //
999 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1000 if app
1001 .db
1002 .has_active_billing_subscription(billing_customer.user_id)
1003 .await?
1004 {
1005 log::info!(
1006 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1007 user_id = billing_customer.user_id,
1008 subscription_id = subscription.id
1009 );
1010 return Ok(());
1011 }
1012
1013 app.db
1014 .create_billing_subscription(&CreateBillingSubscriptionParams {
1015 billing_customer_id: billing_customer.id,
1016 kind: subscription_kind,
1017 stripe_subscription_id: subscription.id.to_string(),
1018 stripe_subscription_status: subscription.status.into(),
1019 stripe_cancellation_reason: subscription
1020 .cancellation_details
1021 .and_then(|details| details.reason)
1022 .map(|reason| reason.into()),
1023 stripe_current_period_start: Some(subscription.current_period_start),
1024 stripe_current_period_end: Some(subscription.current_period_end),
1025 })
1026 .await?;
1027 }
1028
1029 // When the user's subscription changes, we want to refresh their LLM tokens
1030 // to either grant/revoke access.
1031 rpc_server
1032 .refresh_llm_tokens_for_user(billing_customer.user_id)
1033 .await;
1034
1035 Ok(())
1036}
1037
1038#[derive(Debug, Deserialize)]
1039struct GetMonthlySpendParams {
1040 github_user_id: i32,
1041}
1042
1043#[derive(Debug, Serialize)]
1044struct GetMonthlySpendResponse {
1045 monthly_free_tier_spend_in_cents: u32,
1046 monthly_free_tier_allowance_in_cents: u32,
1047 monthly_spend_in_cents: u32,
1048}
1049
1050async fn get_monthly_spend(
1051 Extension(app): Extension<Arc<AppState>>,
1052 Query(params): Query<GetMonthlySpendParams>,
1053) -> Result<Json<GetMonthlySpendResponse>> {
1054 let user = app
1055 .db
1056 .get_user_by_github_user_id(params.github_user_id)
1057 .await?
1058 .ok_or_else(|| anyhow!("user not found"))?;
1059
1060 let Some(llm_db) = app.llm_db.clone() else {
1061 return Err(Error::http(
1062 StatusCode::NOT_IMPLEMENTED,
1063 "LLM database not available".into(),
1064 ));
1065 };
1066
1067 let free_tier = user
1068 .custom_llm_monthly_allowance_in_cents
1069 .map(|allowance| Cents(allowance as u32))
1070 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1071
1072 let spending_for_month = llm_db
1073 .get_user_spending_for_month(user.id, Utc::now())
1074 .await?;
1075
1076 let free_tier_spend = Cents::min(spending_for_month, free_tier);
1077 let monthly_spend = spending_for_month.saturating_sub(free_tier);
1078
1079 Ok(Json(GetMonthlySpendResponse {
1080 monthly_free_tier_spend_in_cents: free_tier_spend.0,
1081 monthly_free_tier_allowance_in_cents: free_tier.0,
1082 monthly_spend_in_cents: monthly_spend.0,
1083 }))
1084}
1085
1086#[derive(Debug, Deserialize)]
1087struct GetCurrentUsageParams {
1088 github_user_id: i32,
1089}
1090
1091#[derive(Debug, Serialize)]
1092struct UsageCounts {
1093 pub used: i32,
1094 pub limit: Option<i32>,
1095 pub remaining: Option<i32>,
1096}
1097
1098#[derive(Debug, Serialize)]
1099struct ModelRequestUsage {
1100 pub model: String,
1101 pub mode: CompletionMode,
1102 pub requests: i32,
1103}
1104
1105#[derive(Debug, Serialize)]
1106struct GetCurrentUsageResponse {
1107 pub model_requests: UsageCounts,
1108 pub model_request_usage: Vec<ModelRequestUsage>,
1109 pub edit_predictions: UsageCounts,
1110}
1111
1112async fn get_current_usage(
1113 Extension(app): Extension<Arc<AppState>>,
1114 Query(params): Query<GetCurrentUsageParams>,
1115) -> Result<Json<GetCurrentUsageResponse>> {
1116 let user = app
1117 .db
1118 .get_user_by_github_user_id(params.github_user_id)
1119 .await?
1120 .ok_or_else(|| anyhow!("user not found"))?;
1121
1122 let Some(llm_db) = app.llm_db.clone() else {
1123 return Err(Error::http(
1124 StatusCode::NOT_IMPLEMENTED,
1125 "LLM database not available".into(),
1126 ));
1127 };
1128
1129 let empty_usage = GetCurrentUsageResponse {
1130 model_requests: UsageCounts {
1131 used: 0,
1132 limit: Some(0),
1133 remaining: Some(0),
1134 },
1135 model_request_usage: Vec::new(),
1136 edit_predictions: UsageCounts {
1137 used: 0,
1138 limit: Some(0),
1139 remaining: Some(0),
1140 },
1141 };
1142
1143 let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1144 return Ok(Json(empty_usage));
1145 };
1146
1147 let subscription_period = maybe!({
1148 let period_start_at = subscription.current_period_start_at()?;
1149 let period_end_at = subscription.current_period_end_at()?;
1150
1151 Some((period_start_at, period_end_at))
1152 });
1153
1154 let Some((period_start_at, period_end_at)) = subscription_period else {
1155 return Ok(Json(empty_usage));
1156 };
1157
1158 let usage = llm_db
1159 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1160 .await?;
1161 let Some(usage) = usage else {
1162 return Ok(Json(empty_usage));
1163 };
1164
1165 let plan = match usage.plan {
1166 SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1167 SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1168 SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1169 };
1170
1171 let model_requests_limit = match plan.model_requests_limit() {
1172 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1173 zed_llm_client::UsageLimit::Unlimited => None,
1174 };
1175 let edit_prediction_limit = match plan.edit_predictions_limit() {
1176 zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1177 zed_llm_client::UsageLimit::Unlimited => None,
1178 };
1179
1180 let subscription_usage_meters = llm_db
1181 .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1182 .await?;
1183
1184 let model_request_usage = subscription_usage_meters
1185 .into_iter()
1186 .filter_map(|(usage_meter, _usage)| {
1187 let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1188
1189 Some(ModelRequestUsage {
1190 model: model.name.clone(),
1191 mode: usage_meter.mode,
1192 requests: usage_meter.requests,
1193 })
1194 })
1195 .collect::<Vec<_>>();
1196
1197 Ok(Json(GetCurrentUsageResponse {
1198 model_requests: UsageCounts {
1199 used: usage.model_requests,
1200 limit: model_requests_limit,
1201 remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1202 },
1203 model_request_usage,
1204 edit_predictions: UsageCounts {
1205 used: usage.edit_predictions,
1206 limit: edit_prediction_limit,
1207 remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1208 },
1209 }))
1210}
1211
1212impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1213 fn from(value: SubscriptionStatus) -> Self {
1214 match value {
1215 SubscriptionStatus::Incomplete => Self::Incomplete,
1216 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1217 SubscriptionStatus::Trialing => Self::Trialing,
1218 SubscriptionStatus::Active => Self::Active,
1219 SubscriptionStatus::PastDue => Self::PastDue,
1220 SubscriptionStatus::Canceled => Self::Canceled,
1221 SubscriptionStatus::Unpaid => Self::Unpaid,
1222 SubscriptionStatus::Paused => Self::Paused,
1223 }
1224 }
1225}
1226
1227impl From<CancellationDetailsReason> for StripeCancellationReason {
1228 fn from(value: CancellationDetailsReason) -> Self {
1229 match value {
1230 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1231 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1232 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1233 }
1234 }
1235}
1236
1237/// Finds or creates a billing customer using the provided customer.
1238async fn find_or_create_billing_customer(
1239 app: &Arc<AppState>,
1240 stripe_client: &stripe::Client,
1241 customer_or_id: Expandable<Customer>,
1242) -> anyhow::Result<Option<billing_customer::Model>> {
1243 let customer_id = match &customer_or_id {
1244 Expandable::Id(id) => id,
1245 Expandable::Object(customer) => customer.id.as_ref(),
1246 };
1247
1248 // If we already have a billing customer record associated with the Stripe customer,
1249 // there's nothing more we need to do.
1250 if let Some(billing_customer) = app
1251 .db
1252 .get_billing_customer_by_stripe_customer_id(customer_id)
1253 .await?
1254 {
1255 return Ok(Some(billing_customer));
1256 }
1257
1258 // If all we have is a customer ID, resolve it to a full customer record by
1259 // hitting the Stripe API.
1260 let customer = match customer_or_id {
1261 Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1262 Expandable::Object(customer) => *customer,
1263 };
1264
1265 let Some(email) = customer.email else {
1266 return Ok(None);
1267 };
1268
1269 let Some(user) = app.db.get_user_by_email(&email).await? else {
1270 return Ok(None);
1271 };
1272
1273 let billing_customer = app
1274 .db
1275 .create_billing_customer(&CreateBillingCustomerParams {
1276 user_id: user.id,
1277 stripe_customer_id: customer.id.to_string(),
1278 })
1279 .await?;
1280
1281 Ok(Some(billing_customer))
1282}
1283
1284const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1285
1286pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
1287 let Some(stripe_billing) = app.stripe_billing.clone() else {
1288 log::warn!("failed to retrieve Stripe billing object");
1289 return;
1290 };
1291 let Some(llm_db) = app.llm_db.clone() else {
1292 log::warn!("failed to retrieve LLM database");
1293 return;
1294 };
1295
1296 let executor = app.executor.clone();
1297 executor.spawn_detached({
1298 let executor = executor.clone();
1299 async move {
1300 loop {
1301 sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
1302 .await
1303 .context("failed to sync LLM usage to Stripe")
1304 .trace_err();
1305 executor
1306 .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
1307 .await;
1308 }
1309 }
1310 });
1311}
1312
1313async fn sync_token_usage_with_stripe(
1314 app: &Arc<AppState>,
1315 llm_db: &Arc<LlmDatabase>,
1316 stripe_billing: &Arc<StripeBilling>,
1317) -> anyhow::Result<()> {
1318 let events = llm_db.get_billing_events().await?;
1319 let user_ids = events
1320 .iter()
1321 .map(|(event, _)| event.user_id)
1322 .collect::<HashSet<UserId>>();
1323 let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1324
1325 for (event, model) in events {
1326 let Some((stripe_db_customer, stripe_db_subscription)) =
1327 stripe_subscriptions.get(&event.user_id)
1328 else {
1329 tracing::warn!(
1330 user_id = event.user_id.0,
1331 "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1332 );
1333 continue;
1334 };
1335 let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1336 .stripe_subscription_id
1337 .parse()
1338 .context("failed to parse stripe subscription id from db")?;
1339 let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1340 .stripe_customer_id
1341 .parse()
1342 .context("failed to parse stripe customer id from db")?;
1343
1344 let stripe_model = stripe_billing
1345 .register_model_for_token_based_usage(&model)
1346 .await?;
1347 stripe_billing
1348 .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1349 .await?;
1350 stripe_billing
1351 .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
1352 .await?;
1353 llm_db.consume_billing_event(event.id).await?;
1354 }
1355
1356 Ok(())
1357}
1358
1359const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1360
1361pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1362 let Some(stripe_billing) = app.stripe_billing.clone() else {
1363 log::warn!("failed to retrieve Stripe billing object");
1364 return;
1365 };
1366 let Some(llm_db) = app.llm_db.clone() else {
1367 log::warn!("failed to retrieve LLM database");
1368 return;
1369 };
1370
1371 let executor = app.executor.clone();
1372 executor.spawn_detached({
1373 let executor = executor.clone();
1374 async move {
1375 loop {
1376 sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1377 .await
1378 .context("failed to sync LLM request usage to Stripe")
1379 .trace_err();
1380 executor
1381 .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1382 .await;
1383 }
1384 }
1385 });
1386}
1387
1388async fn sync_model_request_usage_with_stripe(
1389 app: &Arc<AppState>,
1390 llm_db: &Arc<LlmDatabase>,
1391 stripe_billing: &Arc<StripeBilling>,
1392) -> anyhow::Result<()> {
1393 let usage_meters = llm_db
1394 .get_current_subscription_usage_meters(Utc::now())
1395 .await?;
1396 let user_ids = usage_meters
1397 .iter()
1398 .map(|(_, usage)| usage.user_id)
1399 .collect::<HashSet<UserId>>();
1400 let billing_subscriptions = app
1401 .db
1402 .get_active_zed_pro_billing_subscriptions(user_ids)
1403 .await?;
1404
1405 let claude_3_5_sonnet = stripe_billing
1406 .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1407 .await?;
1408 let claude_3_7_sonnet = stripe_billing
1409 .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1410 .await?;
1411 let claude_3_7_sonnet_max = stripe_billing
1412 .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1413 .await?;
1414
1415 for (usage_meter, usage) in usage_meters {
1416 maybe!(async {
1417 let Some((billing_customer, billing_subscription)) =
1418 billing_subscriptions.get(&usage.user_id)
1419 else {
1420 bail!(
1421 "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1422 usage.user_id
1423 );
1424 };
1425
1426 let stripe_customer_id = billing_customer
1427 .stripe_customer_id
1428 .parse::<stripe::CustomerId>()
1429 .context("failed to parse Stripe customer ID from database")?;
1430 let stripe_subscription_id = billing_subscription
1431 .stripe_subscription_id
1432 .parse::<stripe::SubscriptionId>()
1433 .context("failed to parse Stripe subscription ID from database")?;
1434
1435 let model = llm_db.model_by_id(usage_meter.model_id)?;
1436
1437 let (price, meter_event_name) = match model.name.as_str() {
1438 "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1439 "claude-3-7-sonnet" => match usage_meter.mode {
1440 CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1441 CompletionMode::Max => {
1442 (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1443 }
1444 },
1445 model_name => {
1446 bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1447 }
1448 };
1449
1450 stripe_billing
1451 .subscribe_to_price(&stripe_subscription_id, price)
1452 .await?;
1453 stripe_billing
1454 .bill_model_request_usage(
1455 &stripe_customer_id,
1456 meter_event_name,
1457 usage_meter.requests,
1458 )
1459 .await?;
1460
1461 Ok(())
1462 })
1463 .await
1464 .log_err();
1465 }
1466
1467 Ok(())
1468}