billing.rs

   1use anyhow::{Context, anyhow, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  19    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  20    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
  21    EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
  22};
  23use util::{ResultExt, maybe};
  24
  25use crate::api::events::SnowflakeRow;
  26use crate::db::billing_subscription::{
  27    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  28};
  29use crate::llm::db::subscription_usage_meter::CompletionMode;
  30use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
  31use crate::rpc::{ResultExt as _, Server};
  32use crate::{AppState, Cents, Error, Result};
  33use crate::{db::UserId, llm::db::LlmDatabase};
  34use crate::{
  35    db::{
  36        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  37        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  38        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  39    },
  40    stripe_billing::StripeBilling,
  41};
  42
  43pub fn router() -> Router {
  44    Router::new()
  45        .route(
  46            "/billing/preferences",
  47            get(get_billing_preferences).put(update_billing_preferences),
  48        )
  49        .route(
  50            "/billing/subscriptions",
  51            get(list_billing_subscriptions).post(create_billing_subscription),
  52        )
  53        .route(
  54            "/billing/subscriptions/manage",
  55            post(manage_billing_subscription),
  56        )
  57        .route("/billing/monthly_spend", get(get_monthly_spend))
  58        .route("/billing/usage", get(get_current_usage))
  59}
  60
  61#[derive(Debug, Deserialize)]
  62struct GetBillingPreferencesParams {
  63    github_user_id: i32,
  64}
  65
  66#[derive(Debug, Serialize)]
  67struct BillingPreferencesResponse {
  68    max_monthly_llm_usage_spending_in_cents: i32,
  69    model_request_overages_enabled: bool,
  70    model_request_overages_spend_limit_in_cents: i32,
  71}
  72
  73async fn get_billing_preferences(
  74    Extension(app): Extension<Arc<AppState>>,
  75    Query(params): Query<GetBillingPreferencesParams>,
  76) -> Result<Json<BillingPreferencesResponse>> {
  77    let user = app
  78        .db
  79        .get_user_by_github_user_id(params.github_user_id)
  80        .await?
  81        .ok_or_else(|| anyhow!("user not found"))?;
  82
  83    let preferences = app.db.get_billing_preferences(user.id).await?;
  84
  85    Ok(Json(BillingPreferencesResponse {
  86        max_monthly_llm_usage_spending_in_cents: preferences
  87            .as_ref()
  88            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
  89                preferences.max_monthly_llm_usage_spending_in_cents
  90            }),
  91        model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
  92            preferences.model_request_overages_enabled
  93        }),
  94        model_request_overages_spend_limit_in_cents: preferences
  95            .as_ref()
  96            .map_or(0, |preferences| {
  97                preferences.model_request_overages_spend_limit_in_cents
  98            }),
  99    }))
 100}
 101
 102#[derive(Debug, Deserialize)]
 103struct UpdateBillingPreferencesBody {
 104    github_user_id: i32,
 105    #[serde(default)]
 106    max_monthly_llm_usage_spending_in_cents: i32,
 107    #[serde(default)]
 108    model_request_overages_enabled: bool,
 109    #[serde(default)]
 110    model_request_overages_spend_limit_in_cents: i32,
 111}
 112
 113async fn update_billing_preferences(
 114    Extension(app): Extension<Arc<AppState>>,
 115    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
 116    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 117) -> Result<Json<BillingPreferencesResponse>> {
 118    let user = app
 119        .db
 120        .get_user_by_github_user_id(body.github_user_id)
 121        .await?
 122        .ok_or_else(|| anyhow!("user not found"))?;
 123
 124    let max_monthly_llm_usage_spending_in_cents =
 125        body.max_monthly_llm_usage_spending_in_cents.max(0);
 126    let model_request_overages_spend_limit_in_cents =
 127        body.model_request_overages_spend_limit_in_cents.max(0);
 128
 129    let billing_preferences =
 130        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 131            app.db
 132                .update_billing_preferences(
 133                    user.id,
 134                    &UpdateBillingPreferencesParams {
 135                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 136                            max_monthly_llm_usage_spending_in_cents,
 137                        ),
 138                        model_request_overages_enabled: ActiveValue::set(
 139                            body.model_request_overages_enabled,
 140                        ),
 141                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 142                            model_request_overages_spend_limit_in_cents,
 143                        ),
 144                    },
 145                )
 146                .await?
 147        } else {
 148            app.db
 149                .create_billing_preferences(
 150                    user.id,
 151                    &crate::db::CreateBillingPreferencesParams {
 152                        max_monthly_llm_usage_spending_in_cents,
 153                        model_request_overages_enabled: body.model_request_overages_enabled,
 154                        model_request_overages_spend_limit_in_cents,
 155                    },
 156                )
 157                .await?
 158        };
 159
 160    SnowflakeRow::new(
 161        "Billing Preferences Updated",
 162        Some(user.metrics_id),
 163        user.admin,
 164        None,
 165        json!({
 166            "user_id": user.id,
 167            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 168            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 169            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 170        }),
 171    )
 172    .write(&app.kinesis_client, &app.config.kinesis_stream)
 173    .await
 174    .log_err();
 175
 176    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 177
 178    Ok(Json(BillingPreferencesResponse {
 179        max_monthly_llm_usage_spending_in_cents: billing_preferences
 180            .max_monthly_llm_usage_spending_in_cents,
 181        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 182        model_request_overages_spend_limit_in_cents: billing_preferences
 183            .model_request_overages_spend_limit_in_cents,
 184    }))
 185}
 186
 187#[derive(Debug, Deserialize)]
 188struct ListBillingSubscriptionsParams {
 189    github_user_id: i32,
 190}
 191
 192#[derive(Debug, Serialize)]
 193struct BillingSubscriptionJson {
 194    id: BillingSubscriptionId,
 195    name: String,
 196    status: StripeSubscriptionStatus,
 197    trial_end_at: Option<String>,
 198    cancel_at: Option<String>,
 199    /// Whether this subscription can be canceled.
 200    is_cancelable: bool,
 201}
 202
 203#[derive(Debug, Serialize)]
 204struct ListBillingSubscriptionsResponse {
 205    subscriptions: Vec<BillingSubscriptionJson>,
 206}
 207
 208async fn list_billing_subscriptions(
 209    Extension(app): Extension<Arc<AppState>>,
 210    Query(params): Query<ListBillingSubscriptionsParams>,
 211) -> Result<Json<ListBillingSubscriptionsResponse>> {
 212    let user = app
 213        .db
 214        .get_user_by_github_user_id(params.github_user_id)
 215        .await?
 216        .ok_or_else(|| anyhow!("user not found"))?;
 217
 218    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 219
 220    Ok(Json(ListBillingSubscriptionsResponse {
 221        subscriptions: subscriptions
 222            .into_iter()
 223            .map(|subscription| BillingSubscriptionJson {
 224                id: subscription.id,
 225                name: match subscription.kind {
 226                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 227                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
 228                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
 229                    None => "Zed LLM Usage".to_string(),
 230                },
 231                status: subscription.stripe_subscription_status,
 232                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
 233                    maybe!({
 234                        let end_at = subscription.stripe_current_period_end?;
 235                        let end_at = DateTime::from_timestamp(end_at, 0)?;
 236
 237                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
 238                    })
 239                } else {
 240                    None
 241                },
 242                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 243                    cancel_at
 244                        .and_utc()
 245                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 246                }),
 247                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
 248                    && subscription.stripe_cancel_at.is_none(),
 249            })
 250            .collect(),
 251    }))
 252}
 253
 254#[derive(Debug, Clone, Copy, Deserialize)]
 255#[serde(rename_all = "snake_case")]
 256enum ProductCode {
 257    ZedPro,
 258    ZedProTrial,
 259}
 260
 261#[derive(Debug, Deserialize)]
 262struct CreateBillingSubscriptionBody {
 263    github_user_id: i32,
 264    product: Option<ProductCode>,
 265}
 266
 267#[derive(Debug, Serialize)]
 268struct CreateBillingSubscriptionResponse {
 269    checkout_session_url: String,
 270}
 271
 272/// Initiates a Stripe Checkout session for creating a billing subscription.
 273async fn create_billing_subscription(
 274    Extension(app): Extension<Arc<AppState>>,
 275    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 276) -> Result<Json<CreateBillingSubscriptionResponse>> {
 277    let user = app
 278        .db
 279        .get_user_by_github_user_id(body.github_user_id)
 280        .await?
 281        .ok_or_else(|| anyhow!("user not found"))?;
 282
 283    let Some(stripe_client) = app.stripe_client.clone() else {
 284        log::error!("failed to retrieve Stripe client");
 285        Err(Error::http(
 286            StatusCode::NOT_IMPLEMENTED,
 287            "not supported".into(),
 288        ))?
 289    };
 290    let Some(stripe_billing) = app.stripe_billing.clone() else {
 291        log::error!("failed to retrieve Stripe billing object");
 292        Err(Error::http(
 293            StatusCode::NOT_IMPLEMENTED,
 294            "not supported".into(),
 295        ))?
 296    };
 297    let Some(llm_db) = app.llm_db.clone() else {
 298        log::error!("failed to retrieve LLM database");
 299        Err(Error::http(
 300            StatusCode::NOT_IMPLEMENTED,
 301            "not supported".into(),
 302        ))?
 303    };
 304
 305    if app.db.has_active_billing_subscription(user.id).await? {
 306        return Err(Error::http(
 307            StatusCode::CONFLICT,
 308            "user already has an active subscription".into(),
 309        ));
 310    }
 311
 312    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 313    if let Some(existing_billing_customer) = &existing_billing_customer {
 314        if existing_billing_customer.has_overdue_invoices {
 315            return Err(Error::http(
 316                StatusCode::PAYMENT_REQUIRED,
 317                "user has overdue invoices".into(),
 318            ));
 319        }
 320    }
 321
 322    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 323        CustomerId::from_str(&existing_customer.stripe_customer_id)
 324            .context("failed to parse customer ID")?
 325    } else {
 326        let existing_customer = if let Some(email) = user.email_address.as_deref() {
 327            let customers = Customer::list(
 328                &stripe_client,
 329                &stripe::ListCustomers {
 330                    email: Some(email),
 331                    ..Default::default()
 332                },
 333            )
 334            .await?;
 335
 336            customers.data.first().cloned()
 337        } else {
 338            None
 339        };
 340
 341        if let Some(existing_customer) = existing_customer {
 342            existing_customer.id
 343        } else {
 344            let customer = Customer::create(
 345                &stripe_client,
 346                CreateCustomer {
 347                    email: user.email_address.as_deref(),
 348                    ..Default::default()
 349                },
 350            )
 351            .await?;
 352
 353            customer.id
 354        }
 355    };
 356
 357    let success_url = format!(
 358        "{}/account?checkout_complete=1",
 359        app.config.zed_dot_dev_url()
 360    );
 361
 362    let checkout_session_url = match body.product {
 363        Some(ProductCode::ZedPro) => {
 364            stripe_billing
 365                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
 366                .await?
 367        }
 368        Some(ProductCode::ZedProTrial) => {
 369            if let Some(existing_billing_customer) = &existing_billing_customer {
 370                if existing_billing_customer.trial_started_at.is_some() {
 371                    return Err(Error::http(
 372                        StatusCode::FORBIDDEN,
 373                        "user already used free trial".into(),
 374                    ));
 375                }
 376            }
 377
 378            let feature_flags = app.db.get_user_flags(user.id).await?;
 379
 380            stripe_billing
 381                .checkout_with_zed_pro_trial(
 382                    customer_id,
 383                    &user.github_login,
 384                    feature_flags,
 385                    &success_url,
 386                )
 387                .await?
 388        }
 389        None => {
 390            let default_model = llm_db.model(
 391                zed_llm_client::LanguageModelProvider::Anthropic,
 392                "claude-3-7-sonnet",
 393            )?;
 394            let stripe_model = stripe_billing
 395                .register_model_for_token_based_usage(default_model)
 396                .await?;
 397            stripe_billing
 398                .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
 399                .await?
 400        }
 401    };
 402
 403    Ok(Json(CreateBillingSubscriptionResponse {
 404        checkout_session_url,
 405    }))
 406}
 407
 408#[derive(Debug, PartialEq, Deserialize)]
 409#[serde(rename_all = "snake_case")]
 410enum ManageSubscriptionIntent {
 411    /// The user intends to manage their subscription.
 412    ///
 413    /// This will open the Stripe billing portal without putting the user in a specific flow.
 414    ManageSubscription,
 415    /// The user intends to upgrade to Zed Pro.
 416    UpgradeToPro,
 417    /// The user intends to cancel their subscription.
 418    Cancel,
 419    /// The user intends to stop the cancellation of their subscription.
 420    StopCancellation,
 421}
 422
 423#[derive(Debug, Deserialize)]
 424struct ManageBillingSubscriptionBody {
 425    github_user_id: i32,
 426    intent: ManageSubscriptionIntent,
 427    /// The ID of the subscription to manage.
 428    subscription_id: BillingSubscriptionId,
 429}
 430
 431#[derive(Debug, Serialize)]
 432struct ManageBillingSubscriptionResponse {
 433    billing_portal_session_url: Option<String>,
 434}
 435
 436/// Initiates a Stripe customer portal session for managing a billing subscription.
 437async fn manage_billing_subscription(
 438    Extension(app): Extension<Arc<AppState>>,
 439    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 440) -> Result<Json<ManageBillingSubscriptionResponse>> {
 441    let user = app
 442        .db
 443        .get_user_by_github_user_id(body.github_user_id)
 444        .await?
 445        .ok_or_else(|| anyhow!("user not found"))?;
 446
 447    let Some(stripe_client) = app.stripe_client.clone() else {
 448        log::error!("failed to retrieve Stripe client");
 449        Err(Error::http(
 450            StatusCode::NOT_IMPLEMENTED,
 451            "not supported".into(),
 452        ))?
 453    };
 454
 455    let Some(stripe_billing) = app.stripe_billing.clone() else {
 456        log::error!("failed to retrieve Stripe billing object");
 457        Err(Error::http(
 458            StatusCode::NOT_IMPLEMENTED,
 459            "not supported".into(),
 460        ))?
 461    };
 462
 463    let customer = app
 464        .db
 465        .get_billing_customer_by_user_id(user.id)
 466        .await?
 467        .ok_or_else(|| anyhow!("billing customer not found"))?;
 468    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 469        .context("failed to parse customer ID")?;
 470
 471    let subscription = app
 472        .db
 473        .get_billing_subscription_by_id(body.subscription_id)
 474        .await?
 475        .ok_or_else(|| anyhow!("subscription not found"))?;
 476    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 477        .context("failed to parse subscription ID")?;
 478
 479    if body.intent == ManageSubscriptionIntent::StopCancellation {
 480        let updated_stripe_subscription = Subscription::update(
 481            &stripe_client,
 482            &subscription_id,
 483            stripe::UpdateSubscription {
 484                cancel_at_period_end: Some(false),
 485                ..Default::default()
 486            },
 487        )
 488        .await?;
 489
 490        app.db
 491            .update_billing_subscription(
 492                subscription.id,
 493                &UpdateBillingSubscriptionParams {
 494                    stripe_cancel_at: ActiveValue::set(
 495                        updated_stripe_subscription
 496                            .cancel_at
 497                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 498                            .map(|time| time.naive_utc()),
 499                    ),
 500                    ..Default::default()
 501                },
 502            )
 503            .await?;
 504
 505        return Ok(Json(ManageBillingSubscriptionResponse {
 506            billing_portal_session_url: None,
 507        }));
 508    }
 509
 510    let flow = match body.intent {
 511        ManageSubscriptionIntent::ManageSubscription => None,
 512        ManageSubscriptionIntent::UpgradeToPro => {
 513            let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
 514            let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
 515
 516            let stripe_subscription =
 517                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 518
 519            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 520                && stripe_subscription.items.data.iter().any(|item| {
 521                    item.price
 522                        .as_ref()
 523                        .map_or(false, |price| price.id == zed_pro_price_id)
 524                });
 525            if is_on_zed_pro_trial {
 526                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 527                Subscription::update(
 528                    &stripe_client,
 529                    &stripe_subscription.id,
 530                    stripe::UpdateSubscription {
 531                        trial_end: Some(stripe::Scheduled::now()),
 532                        ..Default::default()
 533                    },
 534                )
 535                .await?;
 536
 537                return Ok(Json(ManageBillingSubscriptionResponse {
 538                    billing_portal_session_url: None,
 539                }));
 540            }
 541
 542            let subscription_item_to_update = stripe_subscription
 543                .items
 544                .data
 545                .iter()
 546                .find_map(|item| {
 547                    let price = item.price.as_ref()?;
 548
 549                    if price.id == zed_free_price_id {
 550                        Some(item.id.clone())
 551                    } else {
 552                        None
 553                    }
 554                })
 555                .ok_or_else(|| anyhow!("No subscription item to update"))?;
 556
 557            Some(CreateBillingPortalSessionFlowData {
 558                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 559                subscription_update_confirm: Some(
 560                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 561                        subscription: subscription.stripe_subscription_id,
 562                        items: vec![
 563                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 564                                id: subscription_item_to_update.to_string(),
 565                                price: Some(zed_pro_price_id.to_string()),
 566                                quantity: Some(1),
 567                            },
 568                        ],
 569                        discounts: None,
 570                    },
 571                ),
 572                ..Default::default()
 573            })
 574        }
 575        ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
 576            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 577            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 578                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 579                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 580                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 581                }),
 582                ..Default::default()
 583            }),
 584            subscription_cancel: Some(
 585                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 586                    subscription: subscription.stripe_subscription_id,
 587                    retention: None,
 588                },
 589            ),
 590            ..Default::default()
 591        }),
 592        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 593    };
 594
 595    let mut params = CreateBillingPortalSession::new(customer_id);
 596    params.flow_data = flow;
 597    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 598    params.return_url = Some(&return_url);
 599
 600    let session = BillingPortalSession::create(&stripe_client, params).await?;
 601
 602    Ok(Json(ManageBillingSubscriptionResponse {
 603        billing_portal_session_url: Some(session.url),
 604    }))
 605}
 606
 607/// The amount of time we wait in between each poll of Stripe events.
 608///
 609/// This value should strike a balance between:
 610///   1. Being short enough that we update quickly when something in Stripe changes
 611///   2. Being long enough that we don't eat into our rate limits.
 612///
 613/// As a point of reference, the Sequin folks say they have this at **500ms**:
 614///
 615/// > We poll the Stripe /events endpoint every 500ms per account
 616/// >
 617/// > — https://blog.sequinstream.com/events-not-webhooks/
 618const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 619
 620/// The maximum number of events to return per page.
 621///
 622/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
 623///
 624/// > Limit can range between 1 and 100, and the default is 10.
 625const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 626
 627/// The number of pages consisting entirely of already-processed events that we
 628/// will see before we stop retrieving events.
 629///
 630/// This is used to prevent over-fetching the Stripe events API for events we've
 631/// already seen and processed.
 632const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 633
 634/// Polls the Stripe events API periodically to reconcile the records in our
 635/// database with the data in Stripe.
 636pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 637    let Some(stripe_client) = app.stripe_client.clone() else {
 638        log::warn!("failed to retrieve Stripe client");
 639        return;
 640    };
 641
 642    let executor = app.executor.clone();
 643    executor.spawn_detached({
 644        let executor = executor.clone();
 645        async move {
 646            loop {
 647                poll_stripe_events(&app, &rpc_server, &stripe_client)
 648                    .await
 649                    .log_err();
 650
 651                executor.sleep(POLL_EVENTS_INTERVAL).await;
 652            }
 653        }
 654    });
 655}
 656
 657async fn poll_stripe_events(
 658    app: &Arc<AppState>,
 659    rpc_server: &Arc<Server>,
 660    stripe_client: &stripe::Client,
 661) -> anyhow::Result<()> {
 662    fn event_type_to_string(event_type: EventType) -> String {
 663        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 664        // so we need to unquote it.
 665        event_type.to_string().trim_matches('"').to_string()
 666    }
 667
 668    let event_types = [
 669        EventType::CustomerCreated,
 670        EventType::CustomerUpdated,
 671        EventType::CustomerSubscriptionCreated,
 672        EventType::CustomerSubscriptionUpdated,
 673        EventType::CustomerSubscriptionPaused,
 674        EventType::CustomerSubscriptionResumed,
 675        EventType::CustomerSubscriptionDeleted,
 676    ]
 677    .into_iter()
 678    .map(event_type_to_string)
 679    .collect::<Vec<_>>();
 680
 681    let mut pages_of_already_processed_events = 0;
 682    let mut unprocessed_events = Vec::new();
 683
 684    log::info!(
 685        "Stripe events: starting retrieval for {}",
 686        event_types.join(", ")
 687    );
 688    let mut params = ListEvents::new();
 689    params.types = Some(event_types.clone());
 690    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 691
 692    let mut event_pages = stripe::Event::list(&stripe_client, &params)
 693        .await?
 694        .paginate(params);
 695
 696    loop {
 697        let processed_event_ids = {
 698            let event_ids = event_pages
 699                .page
 700                .data
 701                .iter()
 702                .map(|event| event.id.as_str())
 703                .collect::<Vec<_>>();
 704            app.db
 705                .get_processed_stripe_events_by_event_ids(&event_ids)
 706                .await?
 707                .into_iter()
 708                .map(|event| event.stripe_event_id)
 709                .collect::<Vec<_>>()
 710        };
 711
 712        let mut processed_events_in_page = 0;
 713        let events_in_page = event_pages.page.data.len();
 714        for event in &event_pages.page.data {
 715            if processed_event_ids.contains(&event.id.to_string()) {
 716                processed_events_in_page += 1;
 717                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 718            } else {
 719                unprocessed_events.push(event.clone());
 720            }
 721        }
 722
 723        if processed_events_in_page == events_in_page {
 724            pages_of_already_processed_events += 1;
 725        }
 726
 727        if event_pages.page.has_more {
 728            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 729            {
 730                log::info!(
 731                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 732                );
 733                break;
 734            } else {
 735                log::info!("Stripe events: retrieving next page");
 736                event_pages = event_pages.next(&stripe_client).await?;
 737            }
 738        } else {
 739            break;
 740        }
 741    }
 742
 743    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 744
 745    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 746    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 747
 748    for event in unprocessed_events {
 749        let event_id = event.id.clone();
 750        let processed_event_params = CreateProcessedStripeEventParams {
 751            stripe_event_id: event.id.to_string(),
 752            stripe_event_type: event_type_to_string(event.type_),
 753            stripe_event_created_timestamp: event.created,
 754        };
 755
 756        // If the event has happened too far in the past, we don't want to
 757        // process it and risk overwriting other more-recent updates.
 758        //
 759        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 760        let one_day = Duration::from_secs(24 * 60 * 60);
 761        let a_day_ago = Utc::now() - one_day;
 762        if a_day_ago.timestamp() > event.created {
 763            log::info!(
 764                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 765                event_id
 766            );
 767            app.db
 768                .create_processed_stripe_event(&processed_event_params)
 769                .await?;
 770
 771            return Ok(());
 772        }
 773
 774        let process_result = match event.type_ {
 775            EventType::CustomerCreated | EventType::CustomerUpdated => {
 776                handle_customer_event(app, stripe_client, event).await
 777            }
 778            EventType::CustomerSubscriptionCreated
 779            | EventType::CustomerSubscriptionUpdated
 780            | EventType::CustomerSubscriptionPaused
 781            | EventType::CustomerSubscriptionResumed
 782            | EventType::CustomerSubscriptionDeleted => {
 783                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 784            }
 785            _ => Ok(()),
 786        };
 787
 788        if let Some(()) = process_result
 789            .with_context(|| format!("failed to process event {event_id} successfully"))
 790            .log_err()
 791        {
 792            app.db
 793                .create_processed_stripe_event(&processed_event_params)
 794                .await?;
 795        }
 796    }
 797
 798    Ok(())
 799}
 800
 801async fn handle_customer_event(
 802    app: &Arc<AppState>,
 803    _stripe_client: &stripe::Client,
 804    event: stripe::Event,
 805) -> anyhow::Result<()> {
 806    let EventObject::Customer(customer) = event.data.object else {
 807        bail!("unexpected event payload for {}", event.id);
 808    };
 809
 810    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 811
 812    let Some(email) = customer.email else {
 813        log::info!("Stripe customer has no email: skipping");
 814        return Ok(());
 815    };
 816
 817    let Some(user) = app.db.get_user_by_email(&email).await? else {
 818        log::info!("no user found for email: skipping");
 819        return Ok(());
 820    };
 821
 822    if let Some(existing_customer) = app
 823        .db
 824        .get_billing_customer_by_stripe_customer_id(&customer.id)
 825        .await?
 826    {
 827        app.db
 828            .update_billing_customer(
 829                existing_customer.id,
 830                &UpdateBillingCustomerParams {
 831                    // For now we just leave the information as-is, as it is not
 832                    // likely to change.
 833                    ..Default::default()
 834                },
 835            )
 836            .await?;
 837    } else {
 838        app.db
 839            .create_billing_customer(&CreateBillingCustomerParams {
 840                user_id: user.id,
 841                stripe_customer_id: customer.id.to_string(),
 842            })
 843            .await?;
 844    }
 845
 846    Ok(())
 847}
 848
 849async fn handle_customer_subscription_event(
 850    app: &Arc<AppState>,
 851    rpc_server: &Arc<Server>,
 852    stripe_client: &stripe::Client,
 853    event: stripe::Event,
 854) -> anyhow::Result<()> {
 855    let EventObject::Subscription(subscription) = event.data.object else {
 856        bail!("unexpected event payload for {}", event.id);
 857    };
 858
 859    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 860
 861    let subscription_kind = maybe!(async {
 862        let stripe_billing = app.stripe_billing.clone()?;
 863
 864        let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
 865        let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
 866
 867        subscription.items.data.iter().find_map(|item| {
 868            let price = item.price.as_ref()?;
 869
 870            if price.id == zed_pro_price_id {
 871                Some(if subscription.status == SubscriptionStatus::Trialing {
 872                    SubscriptionKind::ZedProTrial
 873                } else {
 874                    SubscriptionKind::ZedPro
 875                })
 876            } else if price.id == zed_free_price_id {
 877                Some(SubscriptionKind::ZedFree)
 878            } else {
 879                None
 880            }
 881        })
 882    })
 883    .await;
 884
 885    let billing_customer =
 886        find_or_create_billing_customer(app, stripe_client, subscription.customer)
 887            .await?
 888            .ok_or_else(|| anyhow!("billing customer not found"))?;
 889
 890    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
 891        if subscription.status == SubscriptionStatus::Trialing {
 892            let current_period_start =
 893                DateTime::from_timestamp(subscription.current_period_start, 0)
 894                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;
 895
 896            app.db
 897                .update_billing_customer(
 898                    billing_customer.id,
 899                    &UpdateBillingCustomerParams {
 900                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
 901                        ..Default::default()
 902                    },
 903                )
 904                .await?;
 905        }
 906    }
 907
 908    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
 909        && subscription
 910            .cancellation_details
 911            .as_ref()
 912            .and_then(|details| details.reason)
 913            .map_or(false, |reason| {
 914                reason == CancellationDetailsReason::PaymentFailed
 915            });
 916
 917    if was_canceled_due_to_payment_failure {
 918        app.db
 919            .update_billing_customer(
 920                billing_customer.id,
 921                &UpdateBillingCustomerParams {
 922                    has_overdue_invoices: ActiveValue::set(true),
 923                    ..Default::default()
 924                },
 925            )
 926            .await?;
 927    }
 928
 929    if let Some(existing_subscription) = app
 930        .db
 931        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
 932        .await?
 933    {
 934        let llm_db = app
 935            .llm_db
 936            .clone()
 937            .ok_or_else(|| anyhow!("LLM DB not initialized"))?;
 938
 939        let new_period_start_at =
 940            chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
 941                .ok_or_else(|| anyhow!("No subscription period start"))?;
 942        let new_period_end_at =
 943            chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
 944                .ok_or_else(|| anyhow!("No subscription period end"))?;
 945
 946        llm_db
 947            .transfer_existing_subscription_usage(
 948                billing_customer.user_id,
 949                &existing_subscription,
 950                subscription_kind,
 951                new_period_start_at,
 952                new_period_end_at,
 953            )
 954            .await?;
 955
 956        app.db
 957            .update_billing_subscription(
 958                existing_subscription.id,
 959                &UpdateBillingSubscriptionParams {
 960                    billing_customer_id: ActiveValue::set(billing_customer.id),
 961                    kind: ActiveValue::set(subscription_kind),
 962                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
 963                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
 964                    stripe_cancel_at: ActiveValue::set(
 965                        subscription
 966                            .cancel_at
 967                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 968                            .map(|time| time.naive_utc()),
 969                    ),
 970                    stripe_cancellation_reason: ActiveValue::set(
 971                        subscription
 972                            .cancellation_details
 973                            .and_then(|details| details.reason)
 974                            .map(|reason| reason.into()),
 975                    ),
 976                    stripe_current_period_start: ActiveValue::set(Some(
 977                        subscription.current_period_start,
 978                    )),
 979                    stripe_current_period_end: ActiveValue::set(Some(
 980                        subscription.current_period_end,
 981                    )),
 982                },
 983            )
 984            .await?;
 985    } else {
 986        // If the user already has an active billing subscription, ignore the
 987        // event and return an `Ok` to signal that it was processed
 988        // successfully.
 989        //
 990        // There is the possibility that this could cause us to not create a
 991        // subscription in the following scenario:
 992        //
 993        //   1. User has an active subscription A
 994        //   2. User cancels subscription A
 995        //   3. User creates a new subscription B
 996        //   4. We process the new subscription B before the cancellation of subscription A
 997        //   5. User ends up with no subscriptions
 998        //
 999        // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1000        if app
1001            .db
1002            .has_active_billing_subscription(billing_customer.user_id)
1003            .await?
1004        {
1005            log::info!(
1006                "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1007                user_id = billing_customer.user_id,
1008                subscription_id = subscription.id
1009            );
1010            return Ok(());
1011        }
1012
1013        app.db
1014            .create_billing_subscription(&CreateBillingSubscriptionParams {
1015                billing_customer_id: billing_customer.id,
1016                kind: subscription_kind,
1017                stripe_subscription_id: subscription.id.to_string(),
1018                stripe_subscription_status: subscription.status.into(),
1019                stripe_cancellation_reason: subscription
1020                    .cancellation_details
1021                    .and_then(|details| details.reason)
1022                    .map(|reason| reason.into()),
1023                stripe_current_period_start: Some(subscription.current_period_start),
1024                stripe_current_period_end: Some(subscription.current_period_end),
1025            })
1026            .await?;
1027    }
1028
1029    // When the user's subscription changes, we want to refresh their LLM tokens
1030    // to either grant/revoke access.
1031    rpc_server
1032        .refresh_llm_tokens_for_user(billing_customer.user_id)
1033        .await;
1034
1035    Ok(())
1036}
1037
1038#[derive(Debug, Deserialize)]
1039struct GetMonthlySpendParams {
1040    github_user_id: i32,
1041}
1042
1043#[derive(Debug, Serialize)]
1044struct GetMonthlySpendResponse {
1045    monthly_free_tier_spend_in_cents: u32,
1046    monthly_free_tier_allowance_in_cents: u32,
1047    monthly_spend_in_cents: u32,
1048}
1049
1050async fn get_monthly_spend(
1051    Extension(app): Extension<Arc<AppState>>,
1052    Query(params): Query<GetMonthlySpendParams>,
1053) -> Result<Json<GetMonthlySpendResponse>> {
1054    let user = app
1055        .db
1056        .get_user_by_github_user_id(params.github_user_id)
1057        .await?
1058        .ok_or_else(|| anyhow!("user not found"))?;
1059
1060    let Some(llm_db) = app.llm_db.clone() else {
1061        return Err(Error::http(
1062            StatusCode::NOT_IMPLEMENTED,
1063            "LLM database not available".into(),
1064        ));
1065    };
1066
1067    let free_tier = user
1068        .custom_llm_monthly_allowance_in_cents
1069        .map(|allowance| Cents(allowance as u32))
1070        .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1071
1072    let spending_for_month = llm_db
1073        .get_user_spending_for_month(user.id, Utc::now())
1074        .await?;
1075
1076    let free_tier_spend = Cents::min(spending_for_month, free_tier);
1077    let monthly_spend = spending_for_month.saturating_sub(free_tier);
1078
1079    Ok(Json(GetMonthlySpendResponse {
1080        monthly_free_tier_spend_in_cents: free_tier_spend.0,
1081        monthly_free_tier_allowance_in_cents: free_tier.0,
1082        monthly_spend_in_cents: monthly_spend.0,
1083    }))
1084}
1085
1086#[derive(Debug, Deserialize)]
1087struct GetCurrentUsageParams {
1088    github_user_id: i32,
1089}
1090
1091#[derive(Debug, Serialize)]
1092struct UsageCounts {
1093    pub used: i32,
1094    pub limit: Option<i32>,
1095    pub remaining: Option<i32>,
1096}
1097
1098#[derive(Debug, Serialize)]
1099struct ModelRequestUsage {
1100    pub model: String,
1101    pub mode: CompletionMode,
1102    pub requests: i32,
1103}
1104
1105#[derive(Debug, Serialize)]
1106struct GetCurrentUsageResponse {
1107    pub model_requests: UsageCounts,
1108    pub model_request_usage: Vec<ModelRequestUsage>,
1109    pub edit_predictions: UsageCounts,
1110}
1111
1112async fn get_current_usage(
1113    Extension(app): Extension<Arc<AppState>>,
1114    Query(params): Query<GetCurrentUsageParams>,
1115) -> Result<Json<GetCurrentUsageResponse>> {
1116    let user = app
1117        .db
1118        .get_user_by_github_user_id(params.github_user_id)
1119        .await?
1120        .ok_or_else(|| anyhow!("user not found"))?;
1121
1122    let Some(llm_db) = app.llm_db.clone() else {
1123        return Err(Error::http(
1124            StatusCode::NOT_IMPLEMENTED,
1125            "LLM database not available".into(),
1126        ));
1127    };
1128
1129    let empty_usage = GetCurrentUsageResponse {
1130        model_requests: UsageCounts {
1131            used: 0,
1132            limit: Some(0),
1133            remaining: Some(0),
1134        },
1135        model_request_usage: Vec::new(),
1136        edit_predictions: UsageCounts {
1137            used: 0,
1138            limit: Some(0),
1139            remaining: Some(0),
1140        },
1141    };
1142
1143    let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1144        return Ok(Json(empty_usage));
1145    };
1146
1147    let subscription_period = maybe!({
1148        let period_start_at = subscription.current_period_start_at()?;
1149        let period_end_at = subscription.current_period_end_at()?;
1150
1151        Some((period_start_at, period_end_at))
1152    });
1153
1154    let Some((period_start_at, period_end_at)) = subscription_period else {
1155        return Ok(Json(empty_usage));
1156    };
1157
1158    let usage = llm_db
1159        .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1160        .await?;
1161    let Some(usage) = usage else {
1162        return Ok(Json(empty_usage));
1163    };
1164
1165    let plan = match usage.plan {
1166        SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1167        SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1168        SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1169    };
1170
1171    let model_requests_limit = match plan.model_requests_limit() {
1172        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1173        zed_llm_client::UsageLimit::Unlimited => None,
1174    };
1175    let edit_prediction_limit = match plan.edit_predictions_limit() {
1176        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1177        zed_llm_client::UsageLimit::Unlimited => None,
1178    };
1179
1180    let subscription_usage_meters = llm_db
1181        .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1182        .await?;
1183
1184    let model_request_usage = subscription_usage_meters
1185        .into_iter()
1186        .filter_map(|(usage_meter, _usage)| {
1187            let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1188
1189            Some(ModelRequestUsage {
1190                model: model.name.clone(),
1191                mode: usage_meter.mode,
1192                requests: usage_meter.requests,
1193            })
1194        })
1195        .collect::<Vec<_>>();
1196
1197    Ok(Json(GetCurrentUsageResponse {
1198        model_requests: UsageCounts {
1199            used: usage.model_requests,
1200            limit: model_requests_limit,
1201            remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1202        },
1203        model_request_usage,
1204        edit_predictions: UsageCounts {
1205            used: usage.edit_predictions,
1206            limit: edit_prediction_limit,
1207            remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1208        },
1209    }))
1210}
1211
1212impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1213    fn from(value: SubscriptionStatus) -> Self {
1214        match value {
1215            SubscriptionStatus::Incomplete => Self::Incomplete,
1216            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1217            SubscriptionStatus::Trialing => Self::Trialing,
1218            SubscriptionStatus::Active => Self::Active,
1219            SubscriptionStatus::PastDue => Self::PastDue,
1220            SubscriptionStatus::Canceled => Self::Canceled,
1221            SubscriptionStatus::Unpaid => Self::Unpaid,
1222            SubscriptionStatus::Paused => Self::Paused,
1223        }
1224    }
1225}
1226
1227impl From<CancellationDetailsReason> for StripeCancellationReason {
1228    fn from(value: CancellationDetailsReason) -> Self {
1229        match value {
1230            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1231            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1232            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1233        }
1234    }
1235}
1236
1237/// Finds or creates a billing customer using the provided customer.
1238async fn find_or_create_billing_customer(
1239    app: &Arc<AppState>,
1240    stripe_client: &stripe::Client,
1241    customer_or_id: Expandable<Customer>,
1242) -> anyhow::Result<Option<billing_customer::Model>> {
1243    let customer_id = match &customer_or_id {
1244        Expandable::Id(id) => id,
1245        Expandable::Object(customer) => customer.id.as_ref(),
1246    };
1247
1248    // If we already have a billing customer record associated with the Stripe customer,
1249    // there's nothing more we need to do.
1250    if let Some(billing_customer) = app
1251        .db
1252        .get_billing_customer_by_stripe_customer_id(customer_id)
1253        .await?
1254    {
1255        return Ok(Some(billing_customer));
1256    }
1257
1258    // If all we have is a customer ID, resolve it to a full customer record by
1259    // hitting the Stripe API.
1260    let customer = match customer_or_id {
1261        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1262        Expandable::Object(customer) => *customer,
1263    };
1264
1265    let Some(email) = customer.email else {
1266        return Ok(None);
1267    };
1268
1269    let Some(user) = app.db.get_user_by_email(&email).await? else {
1270        return Ok(None);
1271    };
1272
1273    let billing_customer = app
1274        .db
1275        .create_billing_customer(&CreateBillingCustomerParams {
1276            user_id: user.id,
1277            stripe_customer_id: customer.id.to_string(),
1278        })
1279        .await?;
1280
1281    Ok(Some(billing_customer))
1282}
1283
1284const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1285
1286pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
1287    let Some(stripe_billing) = app.stripe_billing.clone() else {
1288        log::warn!("failed to retrieve Stripe billing object");
1289        return;
1290    };
1291    let Some(llm_db) = app.llm_db.clone() else {
1292        log::warn!("failed to retrieve LLM database");
1293        return;
1294    };
1295
1296    let executor = app.executor.clone();
1297    executor.spawn_detached({
1298        let executor = executor.clone();
1299        async move {
1300            loop {
1301                sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
1302                    .await
1303                    .context("failed to sync LLM usage to Stripe")
1304                    .trace_err();
1305                executor
1306                    .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
1307                    .await;
1308            }
1309        }
1310    });
1311}
1312
1313async fn sync_token_usage_with_stripe(
1314    app: &Arc<AppState>,
1315    llm_db: &Arc<LlmDatabase>,
1316    stripe_billing: &Arc<StripeBilling>,
1317) -> anyhow::Result<()> {
1318    let events = llm_db.get_billing_events().await?;
1319    let user_ids = events
1320        .iter()
1321        .map(|(event, _)| event.user_id)
1322        .collect::<HashSet<UserId>>();
1323    let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1324
1325    for (event, model) in events {
1326        let Some((stripe_db_customer, stripe_db_subscription)) =
1327            stripe_subscriptions.get(&event.user_id)
1328        else {
1329            tracing::warn!(
1330                user_id = event.user_id.0,
1331                "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1332            );
1333            continue;
1334        };
1335        let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1336            .stripe_subscription_id
1337            .parse()
1338            .context("failed to parse stripe subscription id from db")?;
1339        let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1340            .stripe_customer_id
1341            .parse()
1342            .context("failed to parse stripe customer id from db")?;
1343
1344        let stripe_model = stripe_billing
1345            .register_model_for_token_based_usage(&model)
1346            .await?;
1347        stripe_billing
1348            .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1349            .await?;
1350        stripe_billing
1351            .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
1352            .await?;
1353        llm_db.consume_billing_event(event.id).await?;
1354    }
1355
1356    Ok(())
1357}
1358
1359const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1360
1361pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1362    let Some(stripe_billing) = app.stripe_billing.clone() else {
1363        log::warn!("failed to retrieve Stripe billing object");
1364        return;
1365    };
1366    let Some(llm_db) = app.llm_db.clone() else {
1367        log::warn!("failed to retrieve LLM database");
1368        return;
1369    };
1370
1371    let executor = app.executor.clone();
1372    executor.spawn_detached({
1373        let executor = executor.clone();
1374        async move {
1375            loop {
1376                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1377                    .await
1378                    .context("failed to sync LLM request usage to Stripe")
1379                    .trace_err();
1380                executor
1381                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1382                    .await;
1383            }
1384        }
1385    });
1386}
1387
1388async fn sync_model_request_usage_with_stripe(
1389    app: &Arc<AppState>,
1390    llm_db: &Arc<LlmDatabase>,
1391    stripe_billing: &Arc<StripeBilling>,
1392) -> anyhow::Result<()> {
1393    let usage_meters = llm_db
1394        .get_current_subscription_usage_meters(Utc::now())
1395        .await?;
1396    let user_ids = usage_meters
1397        .iter()
1398        .map(|(_, usage)| usage.user_id)
1399        .collect::<HashSet<UserId>>();
1400    let billing_subscriptions = app
1401        .db
1402        .get_active_zed_pro_billing_subscriptions(user_ids)
1403        .await?;
1404
1405    let claude_3_5_sonnet = stripe_billing
1406        .find_price_id_by_lookup_key("claude-3-5-sonnet-requests")
1407        .await?;
1408    let claude_3_7_sonnet = stripe_billing
1409        .find_price_id_by_lookup_key("claude-3-7-sonnet-requests")
1410        .await?;
1411    let claude_3_7_sonnet_max = stripe_billing
1412        .find_price_id_by_lookup_key("claude-3-7-sonnet-requests-max")
1413        .await?;
1414
1415    for (usage_meter, usage) in usage_meters {
1416        maybe!(async {
1417            let Some((billing_customer, billing_subscription)) =
1418                billing_subscriptions.get(&usage.user_id)
1419            else {
1420                bail!(
1421                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1422                    usage.user_id
1423                );
1424            };
1425
1426            let stripe_customer_id = billing_customer
1427                .stripe_customer_id
1428                .parse::<stripe::CustomerId>()
1429                .context("failed to parse Stripe customer ID from database")?;
1430            let stripe_subscription_id = billing_subscription
1431                .stripe_subscription_id
1432                .parse::<stripe::SubscriptionId>()
1433                .context("failed to parse Stripe subscription ID from database")?;
1434
1435            let model = llm_db.model_by_id(usage_meter.model_id)?;
1436
1437            let (price_id, meter_event_name) = match model.name.as_str() {
1438                "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1439                "claude-3-7-sonnet" => match usage_meter.mode {
1440                    CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1441                    CompletionMode::Max => {
1442                        (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1443                    }
1444                },
1445                model_name => {
1446                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1447                }
1448            };
1449
1450            stripe_billing
1451                .subscribe_to_price(&stripe_subscription_id, price_id)
1452                .await?;
1453            stripe_billing
1454                .bill_model_request_usage(
1455                    &stripe_customer_id,
1456                    meter_event_name,
1457                    usage_meter.requests,
1458                )
1459                .await?;
1460
1461            Ok(())
1462        })
1463        .await
1464        .log_err();
1465    }
1466
1467    Ok(())
1468}