billing.rs

   1use anyhow::{Context, anyhow, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  19    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  20    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
  21    EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
  22};
  23use util::{ResultExt, maybe};
  24
  25use crate::api::events::SnowflakeRow;
  26use crate::db::billing_subscription::{
  27    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  28};
  29use crate::llm::db::subscription_usage_meter::CompletionMode;
  30use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
  31use crate::rpc::{ResultExt as _, Server};
  32use crate::{AppState, Cents, Error, Result};
  33use crate::{db::UserId, llm::db::LlmDatabase};
  34use crate::{
  35    db::{
  36        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  37        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  38        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  39    },
  40    stripe_billing::StripeBilling,
  41};
  42
  43pub fn router() -> Router {
  44    Router::new()
  45        .route(
  46            "/billing/preferences",
  47            get(get_billing_preferences).put(update_billing_preferences),
  48        )
  49        .route(
  50            "/billing/subscriptions",
  51            get(list_billing_subscriptions).post(create_billing_subscription),
  52        )
  53        .route(
  54            "/billing/subscriptions/manage",
  55            post(manage_billing_subscription),
  56        )
  57        .route("/billing/monthly_spend", get(get_monthly_spend))
  58        .route("/billing/usage", get(get_current_usage))
  59}
  60
  61#[derive(Debug, Deserialize)]
  62struct GetBillingPreferencesParams {
  63    github_user_id: i32,
  64}
  65
  66#[derive(Debug, Serialize)]
  67struct BillingPreferencesResponse {
  68    max_monthly_llm_usage_spending_in_cents: i32,
  69    model_request_overages_enabled: bool,
  70    model_request_overages_spend_limit_in_cents: i32,
  71}
  72
  73async fn get_billing_preferences(
  74    Extension(app): Extension<Arc<AppState>>,
  75    Query(params): Query<GetBillingPreferencesParams>,
  76) -> Result<Json<BillingPreferencesResponse>> {
  77    let user = app
  78        .db
  79        .get_user_by_github_user_id(params.github_user_id)
  80        .await?
  81        .ok_or_else(|| anyhow!("user not found"))?;
  82
  83    let preferences = app.db.get_billing_preferences(user.id).await?;
  84
  85    Ok(Json(BillingPreferencesResponse {
  86        max_monthly_llm_usage_spending_in_cents: preferences
  87            .as_ref()
  88            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
  89                preferences.max_monthly_llm_usage_spending_in_cents
  90            }),
  91        model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
  92            preferences.model_request_overages_enabled
  93        }),
  94        model_request_overages_spend_limit_in_cents: preferences
  95            .as_ref()
  96            .map_or(0, |preferences| {
  97                preferences.model_request_overages_spend_limit_in_cents
  98            }),
  99    }))
 100}
 101
 102#[derive(Debug, Deserialize)]
 103struct UpdateBillingPreferencesBody {
 104    github_user_id: i32,
 105    #[serde(default)]
 106    max_monthly_llm_usage_spending_in_cents: i32,
 107    #[serde(default)]
 108    model_request_overages_enabled: bool,
 109    #[serde(default)]
 110    model_request_overages_spend_limit_in_cents: i32,
 111}
 112
 113async fn update_billing_preferences(
 114    Extension(app): Extension<Arc<AppState>>,
 115    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
 116    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 117) -> Result<Json<BillingPreferencesResponse>> {
 118    let user = app
 119        .db
 120        .get_user_by_github_user_id(body.github_user_id)
 121        .await?
 122        .ok_or_else(|| anyhow!("user not found"))?;
 123
 124    let max_monthly_llm_usage_spending_in_cents =
 125        body.max_monthly_llm_usage_spending_in_cents.max(0);
 126    let model_request_overages_spend_limit_in_cents =
 127        body.model_request_overages_spend_limit_in_cents.max(0);
 128
 129    let billing_preferences =
 130        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 131            app.db
 132                .update_billing_preferences(
 133                    user.id,
 134                    &UpdateBillingPreferencesParams {
 135                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 136                            max_monthly_llm_usage_spending_in_cents,
 137                        ),
 138                        model_request_overages_enabled: ActiveValue::set(
 139                            body.model_request_overages_enabled,
 140                        ),
 141                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 142                            model_request_overages_spend_limit_in_cents,
 143                        ),
 144                    },
 145                )
 146                .await?
 147        } else {
 148            app.db
 149                .create_billing_preferences(
 150                    user.id,
 151                    &crate::db::CreateBillingPreferencesParams {
 152                        max_monthly_llm_usage_spending_in_cents,
 153                        model_request_overages_enabled: body.model_request_overages_enabled,
 154                        model_request_overages_spend_limit_in_cents,
 155                    },
 156                )
 157                .await?
 158        };
 159
 160    SnowflakeRow::new(
 161        "Billing Preferences Updated",
 162        Some(user.metrics_id),
 163        user.admin,
 164        None,
 165        json!({
 166            "user_id": user.id,
 167            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 168            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 169            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 170        }),
 171    )
 172    .write(&app.kinesis_client, &app.config.kinesis_stream)
 173    .await
 174    .log_err();
 175
 176    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 177
 178    Ok(Json(BillingPreferencesResponse {
 179        max_monthly_llm_usage_spending_in_cents: billing_preferences
 180            .max_monthly_llm_usage_spending_in_cents,
 181        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 182        model_request_overages_spend_limit_in_cents: billing_preferences
 183            .model_request_overages_spend_limit_in_cents,
 184    }))
 185}
 186
 187#[derive(Debug, Deserialize)]
 188struct ListBillingSubscriptionsParams {
 189    github_user_id: i32,
 190}
 191
 192#[derive(Debug, Serialize)]
 193struct BillingSubscriptionJson {
 194    id: BillingSubscriptionId,
 195    name: String,
 196    status: StripeSubscriptionStatus,
 197    trial_end_at: Option<String>,
 198    cancel_at: Option<String>,
 199    /// Whether this subscription can be canceled.
 200    is_cancelable: bool,
 201}
 202
 203#[derive(Debug, Serialize)]
 204struct ListBillingSubscriptionsResponse {
 205    subscriptions: Vec<BillingSubscriptionJson>,
 206}
 207
 208async fn list_billing_subscriptions(
 209    Extension(app): Extension<Arc<AppState>>,
 210    Query(params): Query<ListBillingSubscriptionsParams>,
 211) -> Result<Json<ListBillingSubscriptionsResponse>> {
 212    let user = app
 213        .db
 214        .get_user_by_github_user_id(params.github_user_id)
 215        .await?
 216        .ok_or_else(|| anyhow!("user not found"))?;
 217
 218    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 219
 220    Ok(Json(ListBillingSubscriptionsResponse {
 221        subscriptions: subscriptions
 222            .into_iter()
 223            .map(|subscription| BillingSubscriptionJson {
 224                id: subscription.id,
 225                name: match subscription.kind {
 226                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 227                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
 228                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
 229                    None => "Zed LLM Usage".to_string(),
 230                },
 231                status: subscription.stripe_subscription_status,
 232                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
 233                    maybe!({
 234                        let end_at = subscription.stripe_current_period_end?;
 235                        let end_at = DateTime::from_timestamp(end_at, 0)?;
 236
 237                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
 238                    })
 239                } else {
 240                    None
 241                },
 242                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 243                    cancel_at
 244                        .and_utc()
 245                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 246                }),
 247                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
 248                    && subscription.stripe_cancel_at.is_none(),
 249            })
 250            .collect(),
 251    }))
 252}
 253
 254#[derive(Debug, Clone, Copy, Deserialize)]
 255#[serde(rename_all = "snake_case")]
 256enum ProductCode {
 257    ZedPro,
 258    ZedProTrial,
 259    ZedFree,
 260}
 261
 262#[derive(Debug, Deserialize)]
 263struct CreateBillingSubscriptionBody {
 264    github_user_id: i32,
 265    product: Option<ProductCode>,
 266}
 267
 268#[derive(Debug, Serialize)]
 269struct CreateBillingSubscriptionResponse {
 270    checkout_session_url: String,
 271}
 272
 273/// Initiates a Stripe Checkout session for creating a billing subscription.
 274async fn create_billing_subscription(
 275    Extension(app): Extension<Arc<AppState>>,
 276    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 277) -> Result<Json<CreateBillingSubscriptionResponse>> {
 278    let user = app
 279        .db
 280        .get_user_by_github_user_id(body.github_user_id)
 281        .await?
 282        .ok_or_else(|| anyhow!("user not found"))?;
 283
 284    let Some(stripe_client) = app.stripe_client.clone() else {
 285        log::error!("failed to retrieve Stripe client");
 286        Err(Error::http(
 287            StatusCode::NOT_IMPLEMENTED,
 288            "not supported".into(),
 289        ))?
 290    };
 291    let Some(stripe_billing) = app.stripe_billing.clone() else {
 292        log::error!("failed to retrieve Stripe billing object");
 293        Err(Error::http(
 294            StatusCode::NOT_IMPLEMENTED,
 295            "not supported".into(),
 296        ))?
 297    };
 298    let Some(llm_db) = app.llm_db.clone() else {
 299        log::error!("failed to retrieve LLM database");
 300        Err(Error::http(
 301            StatusCode::NOT_IMPLEMENTED,
 302            "not supported".into(),
 303        ))?
 304    };
 305
 306    if app.db.has_active_billing_subscription(user.id).await? {
 307        return Err(Error::http(
 308            StatusCode::CONFLICT,
 309            "user already has an active subscription".into(),
 310        ));
 311    }
 312
 313    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 314    if let Some(existing_billing_customer) = &existing_billing_customer {
 315        if existing_billing_customer.has_overdue_invoices {
 316            return Err(Error::http(
 317                StatusCode::PAYMENT_REQUIRED,
 318                "user has overdue invoices".into(),
 319            ));
 320        }
 321    }
 322
 323    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 324        CustomerId::from_str(&existing_customer.stripe_customer_id)
 325            .context("failed to parse customer ID")?
 326    } else {
 327        let existing_customer = if let Some(email) = user.email_address.as_deref() {
 328            let customers = Customer::list(
 329                &stripe_client,
 330                &stripe::ListCustomers {
 331                    email: Some(email),
 332                    ..Default::default()
 333                },
 334            )
 335            .await?;
 336
 337            customers.data.first().cloned()
 338        } else {
 339            None
 340        };
 341
 342        if let Some(existing_customer) = existing_customer {
 343            existing_customer.id
 344        } else {
 345            let customer = Customer::create(
 346                &stripe_client,
 347                CreateCustomer {
 348                    email: user.email_address.as_deref(),
 349                    ..Default::default()
 350                },
 351            )
 352            .await?;
 353
 354            customer.id
 355        }
 356    };
 357
 358    let success_url = format!(
 359        "{}/account?checkout_complete=1",
 360        app.config.zed_dot_dev_url()
 361    );
 362
 363    let checkout_session_url = match body.product {
 364        Some(ProductCode::ZedPro) => {
 365            stripe_billing
 366                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
 367                .await?
 368        }
 369        Some(ProductCode::ZedProTrial) => {
 370            if let Some(existing_billing_customer) = &existing_billing_customer {
 371                if existing_billing_customer.trial_started_at.is_some() {
 372                    return Err(Error::http(
 373                        StatusCode::FORBIDDEN,
 374                        "user already used free trial".into(),
 375                    ));
 376                }
 377            }
 378
 379            let feature_flags = app.db.get_user_flags(user.id).await?;
 380
 381            stripe_billing
 382                .checkout_with_zed_pro_trial(
 383                    customer_id,
 384                    &user.github_login,
 385                    feature_flags,
 386                    &success_url,
 387                )
 388                .await?
 389        }
 390        Some(ProductCode::ZedFree) => {
 391            stripe_billing
 392                .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
 393                .await?
 394        }
 395        None => {
 396            let default_model = llm_db.model(
 397                zed_llm_client::LanguageModelProvider::Anthropic,
 398                "claude-3-7-sonnet",
 399            )?;
 400            let stripe_model = stripe_billing
 401                .register_model_for_token_based_usage(default_model)
 402                .await?;
 403            stripe_billing
 404                .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
 405                .await?
 406        }
 407    };
 408
 409    Ok(Json(CreateBillingSubscriptionResponse {
 410        checkout_session_url,
 411    }))
 412}
 413
 414#[derive(Debug, PartialEq, Deserialize)]
 415#[serde(rename_all = "snake_case")]
 416enum ManageSubscriptionIntent {
 417    /// The user intends to manage their subscription.
 418    ///
 419    /// This will open the Stripe billing portal without putting the user in a specific flow.
 420    ManageSubscription,
 421    /// The user intends to upgrade to Zed Pro.
 422    UpgradeToPro,
 423    /// The user intends to cancel their subscription.
 424    Cancel,
 425    /// The user intends to stop the cancellation of their subscription.
 426    StopCancellation,
 427}
 428
 429#[derive(Debug, Deserialize)]
 430struct ManageBillingSubscriptionBody {
 431    github_user_id: i32,
 432    intent: ManageSubscriptionIntent,
 433    /// The ID of the subscription to manage.
 434    subscription_id: BillingSubscriptionId,
 435}
 436
 437#[derive(Debug, Serialize)]
 438struct ManageBillingSubscriptionResponse {
 439    billing_portal_session_url: Option<String>,
 440}
 441
 442/// Initiates a Stripe customer portal session for managing a billing subscription.
 443async fn manage_billing_subscription(
 444    Extension(app): Extension<Arc<AppState>>,
 445    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 446) -> Result<Json<ManageBillingSubscriptionResponse>> {
 447    let user = app
 448        .db
 449        .get_user_by_github_user_id(body.github_user_id)
 450        .await?
 451        .ok_or_else(|| anyhow!("user not found"))?;
 452
 453    let Some(stripe_client) = app.stripe_client.clone() else {
 454        log::error!("failed to retrieve Stripe client");
 455        Err(Error::http(
 456            StatusCode::NOT_IMPLEMENTED,
 457            "not supported".into(),
 458        ))?
 459    };
 460
 461    let Some(stripe_billing) = app.stripe_billing.clone() else {
 462        log::error!("failed to retrieve Stripe billing object");
 463        Err(Error::http(
 464            StatusCode::NOT_IMPLEMENTED,
 465            "not supported".into(),
 466        ))?
 467    };
 468
 469    let customer = app
 470        .db
 471        .get_billing_customer_by_user_id(user.id)
 472        .await?
 473        .ok_or_else(|| anyhow!("billing customer not found"))?;
 474    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 475        .context("failed to parse customer ID")?;
 476
 477    let subscription = app
 478        .db
 479        .get_billing_subscription_by_id(body.subscription_id)
 480        .await?
 481        .ok_or_else(|| anyhow!("subscription not found"))?;
 482    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 483        .context("failed to parse subscription ID")?;
 484
 485    if body.intent == ManageSubscriptionIntent::StopCancellation {
 486        let updated_stripe_subscription = Subscription::update(
 487            &stripe_client,
 488            &subscription_id,
 489            stripe::UpdateSubscription {
 490                cancel_at_period_end: Some(false),
 491                ..Default::default()
 492            },
 493        )
 494        .await?;
 495
 496        app.db
 497            .update_billing_subscription(
 498                subscription.id,
 499                &UpdateBillingSubscriptionParams {
 500                    stripe_cancel_at: ActiveValue::set(
 501                        updated_stripe_subscription
 502                            .cancel_at
 503                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 504                            .map(|time| time.naive_utc()),
 505                    ),
 506                    ..Default::default()
 507                },
 508            )
 509            .await?;
 510
 511        return Ok(Json(ManageBillingSubscriptionResponse {
 512            billing_portal_session_url: None,
 513        }));
 514    }
 515
 516    let flow = match body.intent {
 517        ManageSubscriptionIntent::ManageSubscription => None,
 518        ManageSubscriptionIntent::UpgradeToPro => {
 519            let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
 520            let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
 521
 522            let stripe_subscription =
 523                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 524
 525            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 526                && stripe_subscription.items.data.iter().any(|item| {
 527                    item.price
 528                        .as_ref()
 529                        .map_or(false, |price| price.id == zed_pro_price_id)
 530                });
 531            if is_on_zed_pro_trial {
 532                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 533                Subscription::update(
 534                    &stripe_client,
 535                    &stripe_subscription.id,
 536                    stripe::UpdateSubscription {
 537                        trial_end: Some(stripe::Scheduled::now()),
 538                        ..Default::default()
 539                    },
 540                )
 541                .await?;
 542
 543                return Ok(Json(ManageBillingSubscriptionResponse {
 544                    billing_portal_session_url: None,
 545                }));
 546            }
 547
 548            let subscription_item_to_update = stripe_subscription
 549                .items
 550                .data
 551                .iter()
 552                .find_map(|item| {
 553                    let price = item.price.as_ref()?;
 554
 555                    if price.id == zed_free_price_id {
 556                        Some(item.id.clone())
 557                    } else {
 558                        None
 559                    }
 560                })
 561                .ok_or_else(|| anyhow!("No subscription item to update"))?;
 562
 563            Some(CreateBillingPortalSessionFlowData {
 564                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 565                subscription_update_confirm: Some(
 566                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 567                        subscription: subscription.stripe_subscription_id,
 568                        items: vec![
 569                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 570                                id: subscription_item_to_update.to_string(),
 571                                price: Some(zed_pro_price_id.to_string()),
 572                                quantity: Some(1),
 573                            },
 574                        ],
 575                        discounts: None,
 576                    },
 577                ),
 578                ..Default::default()
 579            })
 580        }
 581        ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
 582            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 583            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 584                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 585                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 586                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 587                }),
 588                ..Default::default()
 589            }),
 590            subscription_cancel: Some(
 591                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 592                    subscription: subscription.stripe_subscription_id,
 593                    retention: None,
 594                },
 595            ),
 596            ..Default::default()
 597        }),
 598        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 599    };
 600
 601    let mut params = CreateBillingPortalSession::new(customer_id);
 602    params.flow_data = flow;
 603    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 604    params.return_url = Some(&return_url);
 605
 606    let session = BillingPortalSession::create(&stripe_client, params).await?;
 607
 608    Ok(Json(ManageBillingSubscriptionResponse {
 609        billing_portal_session_url: Some(session.url),
 610    }))
 611}
 612
 613/// The amount of time we wait in between each poll of Stripe events.
 614///
 615/// This value should strike a balance between:
 616///   1. Being short enough that we update quickly when something in Stripe changes
 617///   2. Being long enough that we don't eat into our rate limits.
 618///
 619/// As a point of reference, the Sequin folks say they have this at **500ms**:
 620///
 621/// > We poll the Stripe /events endpoint every 500ms per account
 622/// >
 623/// > — https://blog.sequinstream.com/events-not-webhooks/
 624const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 625
 626/// The maximum number of events to return per page.
 627///
 628/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
 629///
 630/// > Limit can range between 1 and 100, and the default is 10.
 631const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 632
 633/// The number of pages consisting entirely of already-processed events that we
 634/// will see before we stop retrieving events.
 635///
 636/// This is used to prevent over-fetching the Stripe events API for events we've
 637/// already seen and processed.
 638const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 639
 640/// Polls the Stripe events API periodically to reconcile the records in our
 641/// database with the data in Stripe.
 642pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 643    let Some(stripe_client) = app.stripe_client.clone() else {
 644        log::warn!("failed to retrieve Stripe client");
 645        return;
 646    };
 647
 648    let executor = app.executor.clone();
 649    executor.spawn_detached({
 650        let executor = executor.clone();
 651        async move {
 652            loop {
 653                poll_stripe_events(&app, &rpc_server, &stripe_client)
 654                    .await
 655                    .log_err();
 656
 657                executor.sleep(POLL_EVENTS_INTERVAL).await;
 658            }
 659        }
 660    });
 661}
 662
 663async fn poll_stripe_events(
 664    app: &Arc<AppState>,
 665    rpc_server: &Arc<Server>,
 666    stripe_client: &stripe::Client,
 667) -> anyhow::Result<()> {
 668    fn event_type_to_string(event_type: EventType) -> String {
 669        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 670        // so we need to unquote it.
 671        event_type.to_string().trim_matches('"').to_string()
 672    }
 673
 674    let event_types = [
 675        EventType::CustomerCreated,
 676        EventType::CustomerUpdated,
 677        EventType::CustomerSubscriptionCreated,
 678        EventType::CustomerSubscriptionUpdated,
 679        EventType::CustomerSubscriptionPaused,
 680        EventType::CustomerSubscriptionResumed,
 681        EventType::CustomerSubscriptionDeleted,
 682    ]
 683    .into_iter()
 684    .map(event_type_to_string)
 685    .collect::<Vec<_>>();
 686
 687    let mut pages_of_already_processed_events = 0;
 688    let mut unprocessed_events = Vec::new();
 689
 690    log::info!(
 691        "Stripe events: starting retrieval for {}",
 692        event_types.join(", ")
 693    );
 694    let mut params = ListEvents::new();
 695    params.types = Some(event_types.clone());
 696    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 697
 698    let mut event_pages = stripe::Event::list(&stripe_client, &params)
 699        .await?
 700        .paginate(params);
 701
 702    loop {
 703        let processed_event_ids = {
 704            let event_ids = event_pages
 705                .page
 706                .data
 707                .iter()
 708                .map(|event| event.id.as_str())
 709                .collect::<Vec<_>>();
 710            app.db
 711                .get_processed_stripe_events_by_event_ids(&event_ids)
 712                .await?
 713                .into_iter()
 714                .map(|event| event.stripe_event_id)
 715                .collect::<Vec<_>>()
 716        };
 717
 718        let mut processed_events_in_page = 0;
 719        let events_in_page = event_pages.page.data.len();
 720        for event in &event_pages.page.data {
 721            if processed_event_ids.contains(&event.id.to_string()) {
 722                processed_events_in_page += 1;
 723                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 724            } else {
 725                unprocessed_events.push(event.clone());
 726            }
 727        }
 728
 729        if processed_events_in_page == events_in_page {
 730            pages_of_already_processed_events += 1;
 731        }
 732
 733        if event_pages.page.has_more {
 734            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 735            {
 736                log::info!(
 737                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 738                );
 739                break;
 740            } else {
 741                log::info!("Stripe events: retrieving next page");
 742                event_pages = event_pages.next(&stripe_client).await?;
 743            }
 744        } else {
 745            break;
 746        }
 747    }
 748
 749    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 750
 751    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 752    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 753
 754    for event in unprocessed_events {
 755        let event_id = event.id.clone();
 756        let processed_event_params = CreateProcessedStripeEventParams {
 757            stripe_event_id: event.id.to_string(),
 758            stripe_event_type: event_type_to_string(event.type_),
 759            stripe_event_created_timestamp: event.created,
 760        };
 761
 762        // If the event has happened too far in the past, we don't want to
 763        // process it and risk overwriting other more-recent updates.
 764        //
 765        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 766        let one_day = Duration::from_secs(24 * 60 * 60);
 767        let a_day_ago = Utc::now() - one_day;
 768        if a_day_ago.timestamp() > event.created {
 769            log::info!(
 770                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 771                event_id
 772            );
 773            app.db
 774                .create_processed_stripe_event(&processed_event_params)
 775                .await?;
 776
 777            return Ok(());
 778        }
 779
 780        let process_result = match event.type_ {
 781            EventType::CustomerCreated | EventType::CustomerUpdated => {
 782                handle_customer_event(app, stripe_client, event).await
 783            }
 784            EventType::CustomerSubscriptionCreated
 785            | EventType::CustomerSubscriptionUpdated
 786            | EventType::CustomerSubscriptionPaused
 787            | EventType::CustomerSubscriptionResumed
 788            | EventType::CustomerSubscriptionDeleted => {
 789                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 790            }
 791            _ => Ok(()),
 792        };
 793
 794        if let Some(()) = process_result
 795            .with_context(|| format!("failed to process event {event_id} successfully"))
 796            .log_err()
 797        {
 798            app.db
 799                .create_processed_stripe_event(&processed_event_params)
 800                .await?;
 801        }
 802    }
 803
 804    Ok(())
 805}
 806
 807async fn handle_customer_event(
 808    app: &Arc<AppState>,
 809    _stripe_client: &stripe::Client,
 810    event: stripe::Event,
 811) -> anyhow::Result<()> {
 812    let EventObject::Customer(customer) = event.data.object else {
 813        bail!("unexpected event payload for {}", event.id);
 814    };
 815
 816    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 817
 818    let Some(email) = customer.email else {
 819        log::info!("Stripe customer has no email: skipping");
 820        return Ok(());
 821    };
 822
 823    let Some(user) = app.db.get_user_by_email(&email).await? else {
 824        log::info!("no user found for email: skipping");
 825        return Ok(());
 826    };
 827
 828    if let Some(existing_customer) = app
 829        .db
 830        .get_billing_customer_by_stripe_customer_id(&customer.id)
 831        .await?
 832    {
 833        app.db
 834            .update_billing_customer(
 835                existing_customer.id,
 836                &UpdateBillingCustomerParams {
 837                    // For now we just leave the information as-is, as it is not
 838                    // likely to change.
 839                    ..Default::default()
 840                },
 841            )
 842            .await?;
 843    } else {
 844        app.db
 845            .create_billing_customer(&CreateBillingCustomerParams {
 846                user_id: user.id,
 847                stripe_customer_id: customer.id.to_string(),
 848            })
 849            .await?;
 850    }
 851
 852    Ok(())
 853}
 854
 855async fn handle_customer_subscription_event(
 856    app: &Arc<AppState>,
 857    rpc_server: &Arc<Server>,
 858    stripe_client: &stripe::Client,
 859    event: stripe::Event,
 860) -> anyhow::Result<()> {
 861    let EventObject::Subscription(subscription) = event.data.object else {
 862        bail!("unexpected event payload for {}", event.id);
 863    };
 864
 865    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 866
 867    let subscription_kind = maybe!(async {
 868        let stripe_billing = app.stripe_billing.clone()?;
 869
 870        let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
 871        let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
 872
 873        subscription.items.data.iter().find_map(|item| {
 874            let price = item.price.as_ref()?;
 875
 876            if price.id == zed_pro_price_id {
 877                Some(if subscription.status == SubscriptionStatus::Trialing {
 878                    SubscriptionKind::ZedProTrial
 879                } else {
 880                    SubscriptionKind::ZedPro
 881                })
 882            } else if price.id == zed_free_price_id {
 883                Some(SubscriptionKind::ZedFree)
 884            } else {
 885                None
 886            }
 887        })
 888    })
 889    .await;
 890
 891    let billing_customer =
 892        find_or_create_billing_customer(app, stripe_client, subscription.customer)
 893            .await?
 894            .ok_or_else(|| anyhow!("billing customer not found"))?;
 895
 896    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
 897        if subscription.status == SubscriptionStatus::Trialing {
 898            let current_period_start =
 899                DateTime::from_timestamp(subscription.current_period_start, 0)
 900                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;
 901
 902            app.db
 903                .update_billing_customer(
 904                    billing_customer.id,
 905                    &UpdateBillingCustomerParams {
 906                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
 907                        ..Default::default()
 908                    },
 909                )
 910                .await?;
 911        }
 912    }
 913
 914    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
 915        && subscription
 916            .cancellation_details
 917            .as_ref()
 918            .and_then(|details| details.reason)
 919            .map_or(false, |reason| {
 920                reason == CancellationDetailsReason::PaymentFailed
 921            });
 922
 923    if was_canceled_due_to_payment_failure {
 924        app.db
 925            .update_billing_customer(
 926                billing_customer.id,
 927                &UpdateBillingCustomerParams {
 928                    has_overdue_invoices: ActiveValue::set(true),
 929                    ..Default::default()
 930                },
 931            )
 932            .await?;
 933    }
 934
 935    if let Some(existing_subscription) = app
 936        .db
 937        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
 938        .await?
 939    {
 940        let llm_db = app
 941            .llm_db
 942            .clone()
 943            .ok_or_else(|| anyhow!("LLM DB not initialized"))?;
 944
 945        let new_period_start_at =
 946            chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
 947                .ok_or_else(|| anyhow!("No subscription period start"))?;
 948        let new_period_end_at =
 949            chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
 950                .ok_or_else(|| anyhow!("No subscription period end"))?;
 951
 952        llm_db
 953            .transfer_existing_subscription_usage(
 954                billing_customer.user_id,
 955                &existing_subscription,
 956                subscription_kind,
 957                new_period_start_at,
 958                new_period_end_at,
 959            )
 960            .await?;
 961
 962        app.db
 963            .update_billing_subscription(
 964                existing_subscription.id,
 965                &UpdateBillingSubscriptionParams {
 966                    billing_customer_id: ActiveValue::set(billing_customer.id),
 967                    kind: ActiveValue::set(subscription_kind),
 968                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
 969                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
 970                    stripe_cancel_at: ActiveValue::set(
 971                        subscription
 972                            .cancel_at
 973                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 974                            .map(|time| time.naive_utc()),
 975                    ),
 976                    stripe_cancellation_reason: ActiveValue::set(
 977                        subscription
 978                            .cancellation_details
 979                            .and_then(|details| details.reason)
 980                            .map(|reason| reason.into()),
 981                    ),
 982                    stripe_current_period_start: ActiveValue::set(Some(
 983                        subscription.current_period_start,
 984                    )),
 985                    stripe_current_period_end: ActiveValue::set(Some(
 986                        subscription.current_period_end,
 987                    )),
 988                },
 989            )
 990            .await?;
 991    } else {
 992        // If the user already has an active billing subscription, ignore the
 993        // event and return an `Ok` to signal that it was processed
 994        // successfully.
 995        //
 996        // There is the possibility that this could cause us to not create a
 997        // subscription in the following scenario:
 998        //
 999        //   1. User has an active subscription A
1000        //   2. User cancels subscription A
1001        //   3. User creates a new subscription B
1002        //   4. We process the new subscription B before the cancellation of subscription A
1003        //   5. User ends up with no subscriptions
1004        //
1005        // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1006        if app
1007            .db
1008            .has_active_billing_subscription(billing_customer.user_id)
1009            .await?
1010        {
1011            log::info!(
1012                "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1013                user_id = billing_customer.user_id,
1014                subscription_id = subscription.id
1015            );
1016            return Ok(());
1017        }
1018
1019        app.db
1020            .create_billing_subscription(&CreateBillingSubscriptionParams {
1021                billing_customer_id: billing_customer.id,
1022                kind: subscription_kind,
1023                stripe_subscription_id: subscription.id.to_string(),
1024                stripe_subscription_status: subscription.status.into(),
1025                stripe_cancellation_reason: subscription
1026                    .cancellation_details
1027                    .and_then(|details| details.reason)
1028                    .map(|reason| reason.into()),
1029                stripe_current_period_start: Some(subscription.current_period_start),
1030                stripe_current_period_end: Some(subscription.current_period_end),
1031            })
1032            .await?;
1033    }
1034
1035    // When the user's subscription changes, we want to refresh their LLM tokens
1036    // to either grant/revoke access.
1037    rpc_server
1038        .refresh_llm_tokens_for_user(billing_customer.user_id)
1039        .await;
1040
1041    Ok(())
1042}
1043
/// Query parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID used to look up the user whose spending is requested.
    github_user_id: i32,
}
1048
/// JSON response body returned by `get_monthly_spend`. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// The part of this month's spending covered by the free tier allowance
    /// (the smaller of the month's spending and the allowance).
    monthly_free_tier_spend_in_cents: u32,
    /// The user's free tier allowance: their custom allowance if one is set,
    /// otherwise `FREE_TIER_MONTHLY_SPENDING_LIMIT`.
    monthly_free_tier_allowance_in_cents: u32,
    /// The part of this month's spending that exceeds the free tier allowance.
    monthly_spend_in_cents: u32,
}
1055
1056async fn get_monthly_spend(
1057    Extension(app): Extension<Arc<AppState>>,
1058    Query(params): Query<GetMonthlySpendParams>,
1059) -> Result<Json<GetMonthlySpendResponse>> {
1060    let user = app
1061        .db
1062        .get_user_by_github_user_id(params.github_user_id)
1063        .await?
1064        .ok_or_else(|| anyhow!("user not found"))?;
1065
1066    let Some(llm_db) = app.llm_db.clone() else {
1067        return Err(Error::http(
1068            StatusCode::NOT_IMPLEMENTED,
1069            "LLM database not available".into(),
1070        ));
1071    };
1072
1073    let free_tier = user
1074        .custom_llm_monthly_allowance_in_cents
1075        .map(|allowance| Cents(allowance as u32))
1076        .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1077
1078    let spending_for_month = llm_db
1079        .get_user_spending_for_month(user.id, Utc::now())
1080        .await?;
1081
1082    let free_tier_spend = Cents::min(spending_for_month, free_tier);
1083    let monthly_spend = spending_for_month.saturating_sub(free_tier);
1084
1085    Ok(Json(GetMonthlySpendResponse {
1086        monthly_free_tier_spend_in_cents: free_tier_spend.0,
1087        monthly_free_tier_allowance_in_cents: free_tier.0,
1088        monthly_spend_in_cents: monthly_spend.0,
1089    }))
1090}
1091
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID used to look up the user whose usage is requested.
    github_user_id: i32,
}
1096
/// Usage of a single metered resource as reported by `get_current_usage`.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// How much of the resource has been used.
    pub used: i32,
    /// The plan's limit for the resource; `None` when the plan is unlimited.
    pub limit: Option<i32>,
    /// `limit - used`, never below zero; `None` when the plan is unlimited.
    pub remaining: Option<i32>,
}
1103
/// Request count for one model and completion mode, taken from a subscription
/// usage meter.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the usage meter refers to.
    pub model: String,
    /// Completion mode recorded on the usage meter.
    pub mode: CompletionMode,
    /// Number of requests recorded on the usage meter.
    pub requests: i32,
}
1110
/// JSON response body returned by `get_current_usage`.
#[derive(Debug, Serialize)]
struct GetCurrentUsageResponse {
    /// Overall model request usage for the subscription period.
    pub model_requests: UsageCounts,
    /// Per-model (and per-mode) request counts from the current usage meters.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit prediction usage for the subscription period.
    pub edit_predictions: UsageCounts,
}
1117
1118async fn get_current_usage(
1119    Extension(app): Extension<Arc<AppState>>,
1120    Query(params): Query<GetCurrentUsageParams>,
1121) -> Result<Json<GetCurrentUsageResponse>> {
1122    let user = app
1123        .db
1124        .get_user_by_github_user_id(params.github_user_id)
1125        .await?
1126        .ok_or_else(|| anyhow!("user not found"))?;
1127
1128    let Some(llm_db) = app.llm_db.clone() else {
1129        return Err(Error::http(
1130            StatusCode::NOT_IMPLEMENTED,
1131            "LLM database not available".into(),
1132        ));
1133    };
1134
1135    let empty_usage = GetCurrentUsageResponse {
1136        model_requests: UsageCounts {
1137            used: 0,
1138            limit: Some(0),
1139            remaining: Some(0),
1140        },
1141        model_request_usage: Vec::new(),
1142        edit_predictions: UsageCounts {
1143            used: 0,
1144            limit: Some(0),
1145            remaining: Some(0),
1146        },
1147    };
1148
1149    let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1150        return Ok(Json(empty_usage));
1151    };
1152
1153    let subscription_period = maybe!({
1154        let period_start_at = subscription.current_period_start_at()?;
1155        let period_end_at = subscription.current_period_end_at()?;
1156
1157        Some((period_start_at, period_end_at))
1158    });
1159
1160    let Some((period_start_at, period_end_at)) = subscription_period else {
1161        return Ok(Json(empty_usage));
1162    };
1163
1164    let usage = llm_db
1165        .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1166        .await?;
1167    let Some(usage) = usage else {
1168        return Ok(Json(empty_usage));
1169    };
1170
1171    let plan = match usage.plan {
1172        SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1173        SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1174        SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1175    };
1176
1177    let model_requests_limit = match plan.model_requests_limit() {
1178        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1179        zed_llm_client::UsageLimit::Unlimited => None,
1180    };
1181    let edit_prediction_limit = match plan.edit_predictions_limit() {
1182        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1183        zed_llm_client::UsageLimit::Unlimited => None,
1184    };
1185
1186    let subscription_usage_meters = llm_db
1187        .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1188        .await?;
1189
1190    let model_request_usage = subscription_usage_meters
1191        .into_iter()
1192        .filter_map(|(usage_meter, _usage)| {
1193            let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1194
1195            Some(ModelRequestUsage {
1196                model: model.name.clone(),
1197                mode: usage_meter.mode,
1198                requests: usage_meter.requests,
1199            })
1200        })
1201        .collect::<Vec<_>>();
1202
1203    Ok(Json(GetCurrentUsageResponse {
1204        model_requests: UsageCounts {
1205            used: usage.model_requests,
1206            limit: model_requests_limit,
1207            remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1208        },
1209        model_request_usage,
1210        edit_predictions: UsageCounts {
1211            used: usage.edit_predictions,
1212            limit: edit_prediction_limit,
1213            remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1214        },
1215    }))
1216}
1217
1218impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1219    fn from(value: SubscriptionStatus) -> Self {
1220        match value {
1221            SubscriptionStatus::Incomplete => Self::Incomplete,
1222            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1223            SubscriptionStatus::Trialing => Self::Trialing,
1224            SubscriptionStatus::Active => Self::Active,
1225            SubscriptionStatus::PastDue => Self::PastDue,
1226            SubscriptionStatus::Canceled => Self::Canceled,
1227            SubscriptionStatus::Unpaid => Self::Unpaid,
1228            SubscriptionStatus::Paused => Self::Paused,
1229        }
1230    }
1231}
1232
1233impl From<CancellationDetailsReason> for StripeCancellationReason {
1234    fn from(value: CancellationDetailsReason) -> Self {
1235        match value {
1236            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1237            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1238            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1239        }
1240    }
1241}
1242
1243/// Finds or creates a billing customer using the provided customer.
1244async fn find_or_create_billing_customer(
1245    app: &Arc<AppState>,
1246    stripe_client: &stripe::Client,
1247    customer_or_id: Expandable<Customer>,
1248) -> anyhow::Result<Option<billing_customer::Model>> {
1249    let customer_id = match &customer_or_id {
1250        Expandable::Id(id) => id,
1251        Expandable::Object(customer) => customer.id.as_ref(),
1252    };
1253
1254    // If we already have a billing customer record associated with the Stripe customer,
1255    // there's nothing more we need to do.
1256    if let Some(billing_customer) = app
1257        .db
1258        .get_billing_customer_by_stripe_customer_id(customer_id)
1259        .await?
1260    {
1261        return Ok(Some(billing_customer));
1262    }
1263
1264    // If all we have is a customer ID, resolve it to a full customer record by
1265    // hitting the Stripe API.
1266    let customer = match customer_or_id {
1267        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1268        Expandable::Object(customer) => *customer,
1269    };
1270
1271    let Some(email) = customer.email else {
1272        return Ok(None);
1273    };
1274
1275    let Some(user) = app.db.get_user_by_email(&email).await? else {
1276        return Ok(None);
1277    };
1278
1279    let billing_customer = app
1280        .db
1281        .create_billing_customer(&CreateBillingCustomerParams {
1282            user_id: user.id,
1283            stripe_customer_id: customer.id.to_string(),
1284        })
1285        .await?;
1286
1287    Ok(Some(billing_customer))
1288}
1289
1290const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1291
1292pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
1293    let Some(stripe_billing) = app.stripe_billing.clone() else {
1294        log::warn!("failed to retrieve Stripe billing object");
1295        return;
1296    };
1297    let Some(llm_db) = app.llm_db.clone() else {
1298        log::warn!("failed to retrieve LLM database");
1299        return;
1300    };
1301
1302    let executor = app.executor.clone();
1303    executor.spawn_detached({
1304        let executor = executor.clone();
1305        async move {
1306            loop {
1307                sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
1308                    .await
1309                    .context("failed to sync LLM usage to Stripe")
1310                    .trace_err();
1311                executor
1312                    .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
1313                    .await;
1314            }
1315        }
1316    });
1317}
1318
1319async fn sync_token_usage_with_stripe(
1320    app: &Arc<AppState>,
1321    llm_db: &Arc<LlmDatabase>,
1322    stripe_billing: &Arc<StripeBilling>,
1323) -> anyhow::Result<()> {
1324    let events = llm_db.get_billing_events().await?;
1325    let user_ids = events
1326        .iter()
1327        .map(|(event, _)| event.user_id)
1328        .collect::<HashSet<UserId>>();
1329    let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1330
1331    for (event, model) in events {
1332        let Some((stripe_db_customer, stripe_db_subscription)) =
1333            stripe_subscriptions.get(&event.user_id)
1334        else {
1335            tracing::warn!(
1336                user_id = event.user_id.0,
1337                "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1338            );
1339            continue;
1340        };
1341        let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1342            .stripe_subscription_id
1343            .parse()
1344            .context("failed to parse stripe subscription id from db")?;
1345        let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1346            .stripe_customer_id
1347            .parse()
1348            .context("failed to parse stripe customer id from db")?;
1349
1350        let stripe_model = stripe_billing
1351            .register_model_for_token_based_usage(&model)
1352            .await?;
1353        stripe_billing
1354            .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1355            .await?;
1356        stripe_billing
1357            .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
1358            .await?;
1359        llm_db.consume_billing_event(event.id).await?;
1360    }
1361
1362    Ok(())
1363}
1364
1365const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1366
1367pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1368    let Some(stripe_billing) = app.stripe_billing.clone() else {
1369        log::warn!("failed to retrieve Stripe billing object");
1370        return;
1371    };
1372    let Some(llm_db) = app.llm_db.clone() else {
1373        log::warn!("failed to retrieve LLM database");
1374        return;
1375    };
1376
1377    let executor = app.executor.clone();
1378    executor.spawn_detached({
1379        let executor = executor.clone();
1380        async move {
1381            loop {
1382                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1383                    .await
1384                    .context("failed to sync LLM request usage to Stripe")
1385                    .trace_err();
1386                executor
1387                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1388                    .await;
1389            }
1390        }
1391    });
1392}
1393
1394async fn sync_model_request_usage_with_stripe(
1395    app: &Arc<AppState>,
1396    llm_db: &Arc<LlmDatabase>,
1397    stripe_billing: &Arc<StripeBilling>,
1398) -> anyhow::Result<()> {
1399    let usage_meters = llm_db
1400        .get_current_subscription_usage_meters(Utc::now())
1401        .await?;
1402    let user_ids = usage_meters
1403        .iter()
1404        .map(|(_, usage)| usage.user_id)
1405        .collect::<HashSet<UserId>>();
1406    let billing_subscriptions = app
1407        .db
1408        .get_active_zed_pro_billing_subscriptions(user_ids)
1409        .await?;
1410
1411    let claude_3_5_sonnet = stripe_billing
1412        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1413        .await?;
1414    let claude_3_7_sonnet = stripe_billing
1415        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1416        .await?;
1417    let claude_3_7_sonnet_max = stripe_billing
1418        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1419        .await?;
1420
1421    for (usage_meter, usage) in usage_meters {
1422        maybe!(async {
1423            let Some((billing_customer, billing_subscription)) =
1424                billing_subscriptions.get(&usage.user_id)
1425            else {
1426                bail!(
1427                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1428                    usage.user_id
1429                );
1430            };
1431
1432            let stripe_customer_id = billing_customer
1433                .stripe_customer_id
1434                .parse::<stripe::CustomerId>()
1435                .context("failed to parse Stripe customer ID from database")?;
1436            let stripe_subscription_id = billing_subscription
1437                .stripe_subscription_id
1438                .parse::<stripe::SubscriptionId>()
1439                .context("failed to parse Stripe subscription ID from database")?;
1440
1441            let model = llm_db.model_by_id(usage_meter.model_id)?;
1442
1443            let (price, meter_event_name) = match model.name.as_str() {
1444                "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1445                "claude-3-7-sonnet" => match usage_meter.mode {
1446                    CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1447                    CompletionMode::Max => {
1448                        (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1449                    }
1450                },
1451                model_name => {
1452                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1453                }
1454            };
1455
1456            stripe_billing
1457                .subscribe_to_price(&stripe_subscription_id, price)
1458                .await?;
1459            stripe_billing
1460                .bill_model_request_usage(
1461                    &stripe_customer_id,
1462                    meter_event_name,
1463                    usage_meter.requests,
1464                )
1465                .await?;
1466
1467            Ok(())
1468        })
1469        .await
1470        .log_err();
1471    }
1472
1473    Ok(())
1474}