billing.rs

   1use anyhow::{Context, anyhow, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  19    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  20    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
  21    EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
  22};
  23use util::{ResultExt, maybe};
  24
  25use crate::api::events::SnowflakeRow;
  26use crate::db::billing_subscription::{
  27    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  28};
  29use crate::llm::db::subscription_usage_meter::CompletionMode;
  30use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
  31use crate::rpc::{ResultExt as _, Server};
  32use crate::{AppState, Cents, Error, Result};
  33use crate::{db::UserId, llm::db::LlmDatabase};
  34use crate::{
  35    db::{
  36        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  37        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  38        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  39    },
  40    stripe_billing::StripeBilling,
  41};
  42
  43pub fn router() -> Router {
  44    Router::new()
  45        .route(
  46            "/billing/preferences",
  47            get(get_billing_preferences).put(update_billing_preferences),
  48        )
  49        .route(
  50            "/billing/subscriptions",
  51            get(list_billing_subscriptions).post(create_billing_subscription),
  52        )
  53        .route(
  54            "/billing/subscriptions/manage",
  55            post(manage_billing_subscription),
  56        )
  57        .route("/billing/monthly_spend", get(get_monthly_spend))
  58        .route("/billing/usage", get(get_current_usage))
  59}
  60
  61#[derive(Debug, Deserialize)]
  62struct GetBillingPreferencesParams {
  63    github_user_id: i32,
  64}
  65
  66#[derive(Debug, Serialize)]
  67struct BillingPreferencesResponse {
  68    max_monthly_llm_usage_spending_in_cents: i32,
  69    model_request_overages_enabled: bool,
  70    model_request_overages_spend_limit_in_cents: i32,
  71}
  72
  73async fn get_billing_preferences(
  74    Extension(app): Extension<Arc<AppState>>,
  75    Query(params): Query<GetBillingPreferencesParams>,
  76) -> Result<Json<BillingPreferencesResponse>> {
  77    let user = app
  78        .db
  79        .get_user_by_github_user_id(params.github_user_id)
  80        .await?
  81        .ok_or_else(|| anyhow!("user not found"))?;
  82
  83    let preferences = app.db.get_billing_preferences(user.id).await?;
  84
  85    Ok(Json(BillingPreferencesResponse {
  86        max_monthly_llm_usage_spending_in_cents: preferences
  87            .as_ref()
  88            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
  89                preferences.max_monthly_llm_usage_spending_in_cents
  90            }),
  91        model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
  92            preferences.model_request_overages_enabled
  93        }),
  94        model_request_overages_spend_limit_in_cents: preferences
  95            .as_ref()
  96            .map_or(0, |preferences| {
  97                preferences.model_request_overages_spend_limit_in_cents
  98            }),
  99    }))
 100}
 101
 102#[derive(Debug, Deserialize)]
 103struct UpdateBillingPreferencesBody {
 104    github_user_id: i32,
 105    #[serde(default)]
 106    max_monthly_llm_usage_spending_in_cents: i32,
 107    #[serde(default)]
 108    model_request_overages_enabled: bool,
 109    #[serde(default)]
 110    model_request_overages_spend_limit_in_cents: i32,
 111}
 112
 113async fn update_billing_preferences(
 114    Extension(app): Extension<Arc<AppState>>,
 115    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
 116    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 117) -> Result<Json<BillingPreferencesResponse>> {
 118    let user = app
 119        .db
 120        .get_user_by_github_user_id(body.github_user_id)
 121        .await?
 122        .ok_or_else(|| anyhow!("user not found"))?;
 123
 124    let max_monthly_llm_usage_spending_in_cents =
 125        body.max_monthly_llm_usage_spending_in_cents.max(0);
 126    let model_request_overages_spend_limit_in_cents =
 127        body.model_request_overages_spend_limit_in_cents.max(0);
 128
 129    let billing_preferences =
 130        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 131            app.db
 132                .update_billing_preferences(
 133                    user.id,
 134                    &UpdateBillingPreferencesParams {
 135                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 136                            max_monthly_llm_usage_spending_in_cents,
 137                        ),
 138                        model_request_overages_enabled: ActiveValue::set(
 139                            body.model_request_overages_enabled,
 140                        ),
 141                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 142                            model_request_overages_spend_limit_in_cents,
 143                        ),
 144                    },
 145                )
 146                .await?
 147        } else {
 148            app.db
 149                .create_billing_preferences(
 150                    user.id,
 151                    &crate::db::CreateBillingPreferencesParams {
 152                        max_monthly_llm_usage_spending_in_cents,
 153                        model_request_overages_enabled: body.model_request_overages_enabled,
 154                        model_request_overages_spend_limit_in_cents,
 155                    },
 156                )
 157                .await?
 158        };
 159
 160    SnowflakeRow::new(
 161        "Billing Preferences Updated",
 162        Some(user.metrics_id),
 163        user.admin,
 164        None,
 165        json!({
 166            "user_id": user.id,
 167            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 168            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 169            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 170        }),
 171    )
 172    .write(&app.kinesis_client, &app.config.kinesis_stream)
 173    .await
 174    .log_err();
 175
 176    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 177
 178    Ok(Json(BillingPreferencesResponse {
 179        max_monthly_llm_usage_spending_in_cents: billing_preferences
 180            .max_monthly_llm_usage_spending_in_cents,
 181        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 182        model_request_overages_spend_limit_in_cents: billing_preferences
 183            .model_request_overages_spend_limit_in_cents,
 184    }))
 185}
 186
 187#[derive(Debug, Deserialize)]
 188struct ListBillingSubscriptionsParams {
 189    github_user_id: i32,
 190}
 191
 192#[derive(Debug, Serialize)]
 193struct BillingSubscriptionJson {
 194    id: BillingSubscriptionId,
 195    name: String,
 196    status: StripeSubscriptionStatus,
 197    trial_end_at: Option<String>,
 198    cancel_at: Option<String>,
 199    /// Whether this subscription can be canceled.
 200    is_cancelable: bool,
 201}
 202
 203#[derive(Debug, Serialize)]
 204struct ListBillingSubscriptionsResponse {
 205    subscriptions: Vec<BillingSubscriptionJson>,
 206}
 207
 208async fn list_billing_subscriptions(
 209    Extension(app): Extension<Arc<AppState>>,
 210    Query(params): Query<ListBillingSubscriptionsParams>,
 211) -> Result<Json<ListBillingSubscriptionsResponse>> {
 212    let user = app
 213        .db
 214        .get_user_by_github_user_id(params.github_user_id)
 215        .await?
 216        .ok_or_else(|| anyhow!("user not found"))?;
 217
 218    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 219
 220    Ok(Json(ListBillingSubscriptionsResponse {
 221        subscriptions: subscriptions
 222            .into_iter()
 223            .map(|subscription| BillingSubscriptionJson {
 224                id: subscription.id,
 225                name: match subscription.kind {
 226                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 227                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
 228                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
 229                    None => "Zed LLM Usage".to_string(),
 230                },
 231                status: subscription.stripe_subscription_status,
 232                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
 233                    maybe!({
 234                        let end_at = subscription.stripe_current_period_end?;
 235                        let end_at = DateTime::from_timestamp(end_at, 0)?;
 236
 237                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
 238                    })
 239                } else {
 240                    None
 241                },
 242                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 243                    cancel_at
 244                        .and_utc()
 245                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 246                }),
 247                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
 248                    && subscription.stripe_cancel_at.is_none(),
 249            })
 250            .collect(),
 251    }))
 252}
 253
 254#[derive(Debug, Clone, Copy, Deserialize)]
 255#[serde(rename_all = "snake_case")]
 256enum ProductCode {
 257    ZedPro,
 258    ZedProTrial,
 259}
 260
 261#[derive(Debug, Deserialize)]
 262struct CreateBillingSubscriptionBody {
 263    github_user_id: i32,
 264    product: Option<ProductCode>,
 265}
 266
 267#[derive(Debug, Serialize)]
 268struct CreateBillingSubscriptionResponse {
 269    checkout_session_url: String,
 270}
 271
 272/// Initiates a Stripe Checkout session for creating a billing subscription.
 273async fn create_billing_subscription(
 274    Extension(app): Extension<Arc<AppState>>,
 275    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 276) -> Result<Json<CreateBillingSubscriptionResponse>> {
 277    let user = app
 278        .db
 279        .get_user_by_github_user_id(body.github_user_id)
 280        .await?
 281        .ok_or_else(|| anyhow!("user not found"))?;
 282
 283    let Some(stripe_client) = app.stripe_client.clone() else {
 284        log::error!("failed to retrieve Stripe client");
 285        Err(Error::http(
 286            StatusCode::NOT_IMPLEMENTED,
 287            "not supported".into(),
 288        ))?
 289    };
 290    let Some(stripe_billing) = app.stripe_billing.clone() else {
 291        log::error!("failed to retrieve Stripe billing object");
 292        Err(Error::http(
 293            StatusCode::NOT_IMPLEMENTED,
 294            "not supported".into(),
 295        ))?
 296    };
 297    let Some(llm_db) = app.llm_db.clone() else {
 298        log::error!("failed to retrieve LLM database");
 299        Err(Error::http(
 300            StatusCode::NOT_IMPLEMENTED,
 301            "not supported".into(),
 302        ))?
 303    };
 304
 305    if app.db.has_active_billing_subscription(user.id).await? {
 306        return Err(Error::http(
 307            StatusCode::CONFLICT,
 308            "user already has an active subscription".into(),
 309        ));
 310    }
 311
 312    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 313    if let Some(existing_billing_customer) = &existing_billing_customer {
 314        if existing_billing_customer.has_overdue_invoices {
 315            return Err(Error::http(
 316                StatusCode::PAYMENT_REQUIRED,
 317                "user has overdue invoices".into(),
 318            ));
 319        }
 320    }
 321
 322    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 323        CustomerId::from_str(&existing_customer.stripe_customer_id)
 324            .context("failed to parse customer ID")?
 325    } else {
 326        let existing_customer = if let Some(email) = user.email_address.as_deref() {
 327            let customers = Customer::list(
 328                &stripe_client,
 329                &stripe::ListCustomers {
 330                    email: Some(email),
 331                    ..Default::default()
 332                },
 333            )
 334            .await?;
 335
 336            customers.data.first().cloned()
 337        } else {
 338            None
 339        };
 340
 341        if let Some(existing_customer) = existing_customer {
 342            existing_customer.id
 343        } else {
 344            let customer = Customer::create(
 345                &stripe_client,
 346                CreateCustomer {
 347                    email: user.email_address.as_deref(),
 348                    ..Default::default()
 349                },
 350            )
 351            .await?;
 352
 353            customer.id
 354        }
 355    };
 356
 357    let success_url = format!(
 358        "{}/account?checkout_complete=1",
 359        app.config.zed_dot_dev_url()
 360    );
 361
 362    let checkout_session_url = match body.product {
 363        Some(ProductCode::ZedPro) => {
 364            stripe_billing
 365                .checkout_with_price(
 366                    app.config.zed_pro_price_id()?,
 367                    customer_id,
 368                    &user.github_login,
 369                    &success_url,
 370                )
 371                .await?
 372        }
 373        Some(ProductCode::ZedProTrial) => {
 374            if let Some(existing_billing_customer) = &existing_billing_customer {
 375                if existing_billing_customer.trial_started_at.is_some() {
 376                    return Err(Error::http(
 377                        StatusCode::FORBIDDEN,
 378                        "user already used free trial".into(),
 379                    ));
 380                }
 381            }
 382
 383            stripe_billing
 384                .checkout_with_zed_pro_trial(
 385                    app.config.zed_pro_price_id()?,
 386                    customer_id,
 387                    &user.github_login,
 388                    &success_url,
 389                )
 390                .await?
 391        }
 392        None => {
 393            let default_model = llm_db.model(
 394                zed_llm_client::LanguageModelProvider::Anthropic,
 395                "claude-3-7-sonnet",
 396            )?;
 397            let stripe_model = stripe_billing
 398                .register_model_for_token_based_usage(default_model)
 399                .await?;
 400            stripe_billing
 401                .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
 402                .await?
 403        }
 404    };
 405
 406    Ok(Json(CreateBillingSubscriptionResponse {
 407        checkout_session_url,
 408    }))
 409}
 410
 411#[derive(Debug, PartialEq, Deserialize)]
 412#[serde(rename_all = "snake_case")]
 413enum ManageSubscriptionIntent {
 414    /// The user intends to manage their subscription.
 415    ///
 416    /// This will open the Stripe billing portal without putting the user in a specific flow.
 417    ManageSubscription,
 418    /// The user intends to upgrade to Zed Pro.
 419    UpgradeToPro,
 420    /// The user intends to cancel their subscription.
 421    Cancel,
 422    /// The user intends to stop the cancellation of their subscription.
 423    StopCancellation,
 424}
 425
 426#[derive(Debug, Deserialize)]
 427struct ManageBillingSubscriptionBody {
 428    github_user_id: i32,
 429    intent: ManageSubscriptionIntent,
 430    /// The ID of the subscription to manage.
 431    subscription_id: BillingSubscriptionId,
 432}
 433
 434#[derive(Debug, Serialize)]
 435struct ManageBillingSubscriptionResponse {
 436    billing_portal_session_url: Option<String>,
 437}
 438
 439/// Initiates a Stripe customer portal session for managing a billing subscription.
 440async fn manage_billing_subscription(
 441    Extension(app): Extension<Arc<AppState>>,
 442    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 443) -> Result<Json<ManageBillingSubscriptionResponse>> {
 444    let user = app
 445        .db
 446        .get_user_by_github_user_id(body.github_user_id)
 447        .await?
 448        .ok_or_else(|| anyhow!("user not found"))?;
 449
 450    let Some(stripe_client) = app.stripe_client.clone() else {
 451        log::error!("failed to retrieve Stripe client");
 452        Err(Error::http(
 453            StatusCode::NOT_IMPLEMENTED,
 454            "not supported".into(),
 455        ))?
 456    };
 457
 458    let customer = app
 459        .db
 460        .get_billing_customer_by_user_id(user.id)
 461        .await?
 462        .ok_or_else(|| anyhow!("billing customer not found"))?;
 463    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 464        .context("failed to parse customer ID")?;
 465
 466    let subscription = app
 467        .db
 468        .get_billing_subscription_by_id(body.subscription_id)
 469        .await?
 470        .ok_or_else(|| anyhow!("subscription not found"))?;
 471    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 472        .context("failed to parse subscription ID")?;
 473
 474    if body.intent == ManageSubscriptionIntent::StopCancellation {
 475        let updated_stripe_subscription = Subscription::update(
 476            &stripe_client,
 477            &subscription_id,
 478            stripe::UpdateSubscription {
 479                cancel_at_period_end: Some(false),
 480                ..Default::default()
 481            },
 482        )
 483        .await?;
 484
 485        app.db
 486            .update_billing_subscription(
 487                subscription.id,
 488                &UpdateBillingSubscriptionParams {
 489                    stripe_cancel_at: ActiveValue::set(
 490                        updated_stripe_subscription
 491                            .cancel_at
 492                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 493                            .map(|time| time.naive_utc()),
 494                    ),
 495                    ..Default::default()
 496                },
 497            )
 498            .await?;
 499
 500        return Ok(Json(ManageBillingSubscriptionResponse {
 501            billing_portal_session_url: None,
 502        }));
 503    }
 504
 505    let flow = match body.intent {
 506        ManageSubscriptionIntent::ManageSubscription => None,
 507        ManageSubscriptionIntent::UpgradeToPro => {
 508            let zed_pro_price_id = app.config.zed_pro_price_id()?;
 509            let zed_free_price_id = app.config.zed_free_price_id()?;
 510
 511            let stripe_subscription =
 512                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 513
 514            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 515                && stripe_subscription.items.data.iter().any(|item| {
 516                    item.price
 517                        .as_ref()
 518                        .map_or(false, |price| price.id == zed_pro_price_id)
 519                });
 520            if is_on_zed_pro_trial {
 521                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 522                Subscription::update(
 523                    &stripe_client,
 524                    &stripe_subscription.id,
 525                    stripe::UpdateSubscription {
 526                        trial_end: Some(stripe::Scheduled::now()),
 527                        ..Default::default()
 528                    },
 529                )
 530                .await?;
 531
 532                return Ok(Json(ManageBillingSubscriptionResponse {
 533                    billing_portal_session_url: None,
 534                }));
 535            }
 536
 537            let subscription_item_to_update = stripe_subscription
 538                .items
 539                .data
 540                .iter()
 541                .find_map(|item| {
 542                    let price = item.price.as_ref()?;
 543
 544                    if price.id == zed_free_price_id {
 545                        Some(item.id.clone())
 546                    } else {
 547                        None
 548                    }
 549                })
 550                .ok_or_else(|| anyhow!("No subscription item to update"))?;
 551
 552            Some(CreateBillingPortalSessionFlowData {
 553                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 554                subscription_update_confirm: Some(
 555                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 556                        subscription: subscription.stripe_subscription_id,
 557                        items: vec![
 558                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 559                                id: subscription_item_to_update.to_string(),
 560                                price: Some(zed_pro_price_id.to_string()),
 561                                quantity: Some(1),
 562                            },
 563                        ],
 564                        discounts: None,
 565                    },
 566                ),
 567                ..Default::default()
 568            })
 569        }
 570        ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
 571            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 572            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 573                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 574                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 575                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 576                }),
 577                ..Default::default()
 578            }),
 579            subscription_cancel: Some(
 580                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 581                    subscription: subscription.stripe_subscription_id,
 582                    retention: None,
 583                },
 584            ),
 585            ..Default::default()
 586        }),
 587        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 588    };
 589
 590    let mut params = CreateBillingPortalSession::new(customer_id);
 591    params.flow_data = flow;
 592    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 593    params.return_url = Some(&return_url);
 594
 595    let session = BillingPortalSession::create(&stripe_client, params).await?;
 596
 597    Ok(Json(ManageBillingSubscriptionResponse {
 598        billing_portal_session_url: Some(session.url),
 599    }))
 600}
 601
 602/// The amount of time we wait in between each poll of Stripe events.
 603///
 604/// This value should strike a balance between:
 605///   1. Being short enough that we update quickly when something in Stripe changes
 606///   2. Being long enough that we don't eat into our rate limits.
 607///
 608/// As a point of reference, the Sequin folks say they have this at **500ms**:
 609///
 610/// > We poll the Stripe /events endpoint every 500ms per account
 611/// >
 612/// > — https://blog.sequinstream.com/events-not-webhooks/
 613const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 614
 615/// The maximum number of events to return per page.
 616///
 617/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
 618///
 619/// > Limit can range between 1 and 100, and the default is 10.
 620const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 621
 622/// The number of pages consisting entirely of already-processed events that we
 623/// will see before we stop retrieving events.
 624///
 625/// This is used to prevent over-fetching the Stripe events API for events we've
 626/// already seen and processed.
 627const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 628
 629/// Polls the Stripe events API periodically to reconcile the records in our
 630/// database with the data in Stripe.
 631pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 632    let Some(stripe_client) = app.stripe_client.clone() else {
 633        log::warn!("failed to retrieve Stripe client");
 634        return;
 635    };
 636
 637    let executor = app.executor.clone();
 638    executor.spawn_detached({
 639        let executor = executor.clone();
 640        async move {
 641            loop {
 642                poll_stripe_events(&app, &rpc_server, &stripe_client)
 643                    .await
 644                    .log_err();
 645
 646                executor.sleep(POLL_EVENTS_INTERVAL).await;
 647            }
 648        }
 649    });
 650}
 651
 652async fn poll_stripe_events(
 653    app: &Arc<AppState>,
 654    rpc_server: &Arc<Server>,
 655    stripe_client: &stripe::Client,
 656) -> anyhow::Result<()> {
 657    fn event_type_to_string(event_type: EventType) -> String {
 658        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 659        // so we need to unquote it.
 660        event_type.to_string().trim_matches('"').to_string()
 661    }
 662
 663    let event_types = [
 664        EventType::CustomerCreated,
 665        EventType::CustomerUpdated,
 666        EventType::CustomerSubscriptionCreated,
 667        EventType::CustomerSubscriptionUpdated,
 668        EventType::CustomerSubscriptionPaused,
 669        EventType::CustomerSubscriptionResumed,
 670        EventType::CustomerSubscriptionDeleted,
 671    ]
 672    .into_iter()
 673    .map(event_type_to_string)
 674    .collect::<Vec<_>>();
 675
 676    let mut pages_of_already_processed_events = 0;
 677    let mut unprocessed_events = Vec::new();
 678
 679    log::info!(
 680        "Stripe events: starting retrieval for {}",
 681        event_types.join(", ")
 682    );
 683    let mut params = ListEvents::new();
 684    params.types = Some(event_types.clone());
 685    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 686
 687    let mut event_pages = stripe::Event::list(&stripe_client, &params)
 688        .await?
 689        .paginate(params);
 690
 691    loop {
 692        let processed_event_ids = {
 693            let event_ids = event_pages
 694                .page
 695                .data
 696                .iter()
 697                .map(|event| event.id.as_str())
 698                .collect::<Vec<_>>();
 699            app.db
 700                .get_processed_stripe_events_by_event_ids(&event_ids)
 701                .await?
 702                .into_iter()
 703                .map(|event| event.stripe_event_id)
 704                .collect::<Vec<_>>()
 705        };
 706
 707        let mut processed_events_in_page = 0;
 708        let events_in_page = event_pages.page.data.len();
 709        for event in &event_pages.page.data {
 710            if processed_event_ids.contains(&event.id.to_string()) {
 711                processed_events_in_page += 1;
 712                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 713            } else {
 714                unprocessed_events.push(event.clone());
 715            }
 716        }
 717
 718        if processed_events_in_page == events_in_page {
 719            pages_of_already_processed_events += 1;
 720        }
 721
 722        if event_pages.page.has_more {
 723            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 724            {
 725                log::info!(
 726                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 727                );
 728                break;
 729            } else {
 730                log::info!("Stripe events: retrieving next page");
 731                event_pages = event_pages.next(&stripe_client).await?;
 732            }
 733        } else {
 734            break;
 735        }
 736    }
 737
 738    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 739
 740    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 741    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 742
 743    for event in unprocessed_events {
 744        let event_id = event.id.clone();
 745        let processed_event_params = CreateProcessedStripeEventParams {
 746            stripe_event_id: event.id.to_string(),
 747            stripe_event_type: event_type_to_string(event.type_),
 748            stripe_event_created_timestamp: event.created,
 749        };
 750
 751        // If the event has happened too far in the past, we don't want to
 752        // process it and risk overwriting other more-recent updates.
 753        //
 754        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 755        let one_day = Duration::from_secs(24 * 60 * 60);
 756        let a_day_ago = Utc::now() - one_day;
 757        if a_day_ago.timestamp() > event.created {
 758            log::info!(
 759                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 760                event_id
 761            );
 762            app.db
 763                .create_processed_stripe_event(&processed_event_params)
 764                .await?;
 765
 766            return Ok(());
 767        }
 768
 769        let process_result = match event.type_ {
 770            EventType::CustomerCreated | EventType::CustomerUpdated => {
 771                handle_customer_event(app, stripe_client, event).await
 772            }
 773            EventType::CustomerSubscriptionCreated
 774            | EventType::CustomerSubscriptionUpdated
 775            | EventType::CustomerSubscriptionPaused
 776            | EventType::CustomerSubscriptionResumed
 777            | EventType::CustomerSubscriptionDeleted => {
 778                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 779            }
 780            _ => Ok(()),
 781        };
 782
 783        if let Some(()) = process_result
 784            .with_context(|| format!("failed to process event {event_id} successfully"))
 785            .log_err()
 786        {
 787            app.db
 788                .create_processed_stripe_event(&processed_event_params)
 789                .await?;
 790        }
 791    }
 792
 793    Ok(())
 794}
 795
 796async fn handle_customer_event(
 797    app: &Arc<AppState>,
 798    _stripe_client: &stripe::Client,
 799    event: stripe::Event,
 800) -> anyhow::Result<()> {
 801    let EventObject::Customer(customer) = event.data.object else {
 802        bail!("unexpected event payload for {}", event.id);
 803    };
 804
 805    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 806
 807    let Some(email) = customer.email else {
 808        log::info!("Stripe customer has no email: skipping");
 809        return Ok(());
 810    };
 811
 812    let Some(user) = app.db.get_user_by_email(&email).await? else {
 813        log::info!("no user found for email: skipping");
 814        return Ok(());
 815    };
 816
 817    if let Some(existing_customer) = app
 818        .db
 819        .get_billing_customer_by_stripe_customer_id(&customer.id)
 820        .await?
 821    {
 822        app.db
 823            .update_billing_customer(
 824                existing_customer.id,
 825                &UpdateBillingCustomerParams {
 826                    // For now we just leave the information as-is, as it is not
 827                    // likely to change.
 828                    ..Default::default()
 829                },
 830            )
 831            .await?;
 832    } else {
 833        app.db
 834            .create_billing_customer(&CreateBillingCustomerParams {
 835                user_id: user.id,
 836                stripe_customer_id: customer.id.to_string(),
 837            })
 838            .await?;
 839    }
 840
 841    Ok(())
 842}
 843
/// Handles a Stripe subscription event by syncing the subscription into our
/// database. The dispatcher routes `CustomerSubscriptionCreated`, `Updated`,
/// `Paused`, `Resumed` and `Deleted` events here.
///
/// In order, this:
/// 1. Derives the [`SubscriptionKind`] from the subscription's items by
///    matching their prices against the configured Zed Pro / Zed Free price IDs.
/// 2. Finds (or creates) the billing customer for the subscription's Stripe
///    customer.
/// 3. Records the trial start date when the subscription is a trialing Zed Pro
///    trial.
/// 4. Flags the billing customer as having overdue invoices when the
///    subscription was canceled because a payment failed.
/// 5. Updates the matching billing subscription (after transferring its usage
///    via the LLM database), or creates a new one. Creation is skipped, and the
///    function returns `Ok(())` without refreshing tokens, when the user
///    already has an active billing subscription.
/// 6. Refreshes the user's LLM tokens.
///
/// # Errors
///
/// Returns an error if the payload is not a subscription, no billing customer
/// can be found or created, a period timestamp cannot be converted to a
/// `DateTime`, the LLM database is not initialized (when an existing
/// subscription is being updated), or any database call fails.
async fn handle_customer_subscription_event(
    app: &Arc<AppState>,
    rpc_server: &Arc<Server>,
    stripe_client: &stripe::Client,
    event: stripe::Event,
) -> anyhow::Result<()> {
    let EventObject::Subscription(subscription) = event.data.object else {
        bail!("unexpected event payload for {}", event.id);
    };

    log::info!("handling Stripe {} event: {}", event.type_, event.id);

    // The kind comes from the first subscription item whose price matches a
    // configured Zed price. It is `None` when either price ID is unavailable in
    // the config or when no item uses a recognized price.
    let subscription_kind = maybe!({
        let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
        let zed_free_price_id = app.config.zed_free_price_id().ok()?;

        subscription.items.data.iter().find_map(|item| {
            let price = item.price.as_ref()?;

            if price.id == zed_pro_price_id {
                // A Zed Pro price on a subscription that is still trialing is
                // treated as a Zed Pro trial.
                Some(if subscription.status == SubscriptionStatus::Trialing {
                    SubscriptionKind::ZedProTrial
                } else {
                    SubscriptionKind::ZedPro
                })
            } else if price.id == zed_free_price_id {
                Some(SubscriptionKind::ZedFree)
            } else {
                None
            }
        })
    });

    let billing_customer =
        find_or_create_billing_customer(app, stripe_client, subscription.customer)
            .await?
            .ok_or_else(|| anyhow!("billing customer not found"))?;

    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
        // `ZedProTrial` is only derived above for trialing subscriptions, so this
        // check is a redundant guard.
        if subscription.status == SubscriptionStatus::Trialing {
            // The start of the current period is recorded as the trial start.
            let current_period_start =
                DateTime::from_timestamp(subscription.current_period_start, 0)
                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;

            app.db
                .update_billing_customer(
                    billing_customer.id,
                    &UpdateBillingCustomerParams {
                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
                        ..Default::default()
                    },
                )
                .await?;
        }
    }

    // True only for canceled subscriptions whose Stripe cancellation reason is
    // `PaymentFailed`.
    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
        && subscription
            .cancellation_details
            .as_ref()
            .and_then(|details| details.reason)
            .map_or(false, |reason| {
                reason == CancellationDetailsReason::PaymentFailed
            });

    if was_canceled_due_to_payment_failure {
        app.db
            .update_billing_customer(
                billing_customer.id,
                &UpdateBillingCustomerParams {
                    has_overdue_invoices: ActiveValue::set(true),
                    ..Default::default()
                },
            )
            .await?;
    }

    if let Some(existing_subscription) = app
        .db
        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
        .await?
    {
        // We already track this Stripe subscription: transfer its usage, then
        // overwrite the stored fields with the latest values from Stripe.
        let llm_db = app
            .llm_db
            .clone()
            .ok_or_else(|| anyhow!("LLM DB not initialized"))?;

        let new_period_start_at =
            chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
                .ok_or_else(|| anyhow!("No subscription period start"))?;
        let new_period_end_at =
            chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
                .ok_or_else(|| anyhow!("No subscription period end"))?;

        // NOTE(review): presumably carries the user's recorded usage over to the
        // new kind/period. See `LlmDatabase::transfer_existing_subscription_usage`
        // for the exact semantics.
        llm_db
            .transfer_existing_subscription_usage(
                billing_customer.user_id,
                &existing_subscription,
                subscription_kind,
                new_period_start_at,
                new_period_end_at,
            )
            .await?;

        app.db
            .update_billing_subscription(
                existing_subscription.id,
                &UpdateBillingSubscriptionParams {
                    billing_customer_id: ActiveValue::set(billing_customer.id),
                    kind: ActiveValue::set(subscription_kind),
                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
                    stripe_cancel_at: ActiveValue::set(
                        subscription
                            .cancel_at
                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
                            .map(|time| time.naive_utc()),
                    ),
                    stripe_cancellation_reason: ActiveValue::set(
                        subscription
                            .cancellation_details
                            .and_then(|details| details.reason)
                            .map(|reason| reason.into()),
                    ),
                    stripe_current_period_start: ActiveValue::set(Some(
                        subscription.current_period_start,
                    )),
                    stripe_current_period_end: ActiveValue::set(Some(
                        subscription.current_period_end,
                    )),
                },
            )
            .await?;
    } else {
        // If the user already has an active billing subscription, ignore the
        // event and return `Ok` to signal that it was processed successfully.
        //
        // This could cause us to miss creating a subscription in the following
        // scenario:
        //
        //   1. User has an active subscription A
        //   2. User cancels subscription A
        //   3. User creates a new subscription B
        //   4. We process the new subscription B before the cancellation of subscription A
        //   5. User ends up with no subscriptions
        //
        // In theory this shouldn't happen, because we try to process events in the order they occur.
        if app
            .db
            .has_active_billing_subscription(billing_customer.user_id)
            .await?
        {
            log::info!(
                "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
                user_id = billing_customer.user_id,
                subscription_id = subscription.id
            );
            return Ok(());
        }

        app.db
            .create_billing_subscription(&CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: subscription_kind,
                stripe_subscription_id: subscription.id.to_string(),
                stripe_subscription_status: subscription.status.into(),
                stripe_cancellation_reason: subscription
                    .cancellation_details
                    .and_then(|details| details.reason)
                    .map(|reason| reason.into()),
                stripe_current_period_start: Some(subscription.current_period_start),
                stripe_current_period_end: Some(subscription.current_period_end),
            })
            .await?;
    }

    // When the user's subscription changes, we want to refresh their LLM tokens
    // to either grant/revoke access.
    rpc_server
        .refresh_llm_tokens_for_user(billing_customer.user_id)
        .await;

    Ok(())
}
1029
/// Query parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID of the user whose spend is requested.
    github_user_id: i32,
}
1034
/// JSON response returned by `get_monthly_spend`. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// Part of this month's spend covered by the free tier: the total spend,
    /// capped at the free-tier allowance.
    monthly_free_tier_spend_in_cents: u32,
    /// The user's free-tier allowance for the month.
    monthly_free_tier_allowance_in_cents: u32,
    /// Part of this month's spend that exceeds the free-tier allowance.
    monthly_spend_in_cents: u32,
}
1041
1042async fn get_monthly_spend(
1043    Extension(app): Extension<Arc<AppState>>,
1044    Query(params): Query<GetMonthlySpendParams>,
1045) -> Result<Json<GetMonthlySpendResponse>> {
1046    let user = app
1047        .db
1048        .get_user_by_github_user_id(params.github_user_id)
1049        .await?
1050        .ok_or_else(|| anyhow!("user not found"))?;
1051
1052    let Some(llm_db) = app.llm_db.clone() else {
1053        return Err(Error::http(
1054            StatusCode::NOT_IMPLEMENTED,
1055            "LLM database not available".into(),
1056        ));
1057    };
1058
1059    let free_tier = user
1060        .custom_llm_monthly_allowance_in_cents
1061        .map(|allowance| Cents(allowance as u32))
1062        .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1063
1064    let spending_for_month = llm_db
1065        .get_user_spending_for_month(user.id, Utc::now())
1066        .await?;
1067
1068    let free_tier_spend = Cents::min(spending_for_month, free_tier);
1069    let monthly_spend = spending_for_month.saturating_sub(free_tier);
1070
1071    Ok(Json(GetMonthlySpendResponse {
1072        monthly_free_tier_spend_in_cents: free_tier_spend.0,
1073        monthly_free_tier_allowance_in_cents: free_tier.0,
1074        monthly_spend_in_cents: monthly_spend.0,
1075    }))
1076}
1077
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID of the user whose usage is requested.
    github_user_id: i32,
}
1082
/// Usage of one metered feature, as reported by `get_current_usage`.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Number of uses recorded in the current period.
    pub used: i32,
    /// The plan's limit, or `None` when the plan is unlimited for this feature.
    pub limit: Option<i32>,
    /// Uses left before reaching `limit` (clamped at zero), or `None` when
    /// there is no limit.
    pub remaining: Option<i32>,
}
1089
/// JSON response returned by `get_current_usage`.
#[derive(Debug, Serialize)]
struct GetCurrentUsageResponse {
    /// Usage of model requests in the current billing period.
    pub model_requests: UsageCounts,
    /// Usage of edit predictions in the current billing period.
    pub edit_predictions: UsageCounts,
}
1095
1096async fn get_current_usage(
1097    Extension(app): Extension<Arc<AppState>>,
1098    Query(params): Query<GetCurrentUsageParams>,
1099) -> Result<Json<GetCurrentUsageResponse>> {
1100    let user = app
1101        .db
1102        .get_user_by_github_user_id(params.github_user_id)
1103        .await?
1104        .ok_or_else(|| anyhow!("user not found"))?;
1105
1106    let Some(llm_db) = app.llm_db.clone() else {
1107        return Err(Error::http(
1108            StatusCode::NOT_IMPLEMENTED,
1109            "LLM database not available".into(),
1110        ));
1111    };
1112
1113    let empty_usage = GetCurrentUsageResponse {
1114        model_requests: UsageCounts {
1115            used: 0,
1116            limit: Some(0),
1117            remaining: Some(0),
1118        },
1119        edit_predictions: UsageCounts {
1120            used: 0,
1121            limit: Some(0),
1122            remaining: Some(0),
1123        },
1124    };
1125
1126    let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1127        return Ok(Json(empty_usage));
1128    };
1129
1130    let subscription_period = maybe!({
1131        let period_start_at = subscription.current_period_start_at()?;
1132        let period_end_at = subscription.current_period_end_at()?;
1133
1134        Some((period_start_at, period_end_at))
1135    });
1136
1137    let Some((period_start_at, period_end_at)) = subscription_period else {
1138        return Ok(Json(empty_usage));
1139    };
1140
1141    let usage = llm_db
1142        .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1143        .await?;
1144    let Some(usage) = usage else {
1145        return Ok(Json(empty_usage));
1146    };
1147
1148    let plan = match usage.plan {
1149        SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1150        SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1151        SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1152    };
1153
1154    let model_requests_limit = match plan.model_requests_limit() {
1155        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1156        zed_llm_client::UsageLimit::Unlimited => None,
1157    };
1158    let edit_prediction_limit = match plan.edit_predictions_limit() {
1159        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1160        zed_llm_client::UsageLimit::Unlimited => None,
1161    };
1162
1163    Ok(Json(GetCurrentUsageResponse {
1164        model_requests: UsageCounts {
1165            used: usage.model_requests,
1166            limit: model_requests_limit,
1167            remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1168        },
1169        edit_predictions: UsageCounts {
1170            used: usage.edit_predictions,
1171            limit: edit_prediction_limit,
1172            remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1173        },
1174    }))
1175}
1176
1177impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1178    fn from(value: SubscriptionStatus) -> Self {
1179        match value {
1180            SubscriptionStatus::Incomplete => Self::Incomplete,
1181            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1182            SubscriptionStatus::Trialing => Self::Trialing,
1183            SubscriptionStatus::Active => Self::Active,
1184            SubscriptionStatus::PastDue => Self::PastDue,
1185            SubscriptionStatus::Canceled => Self::Canceled,
1186            SubscriptionStatus::Unpaid => Self::Unpaid,
1187            SubscriptionStatus::Paused => Self::Paused,
1188        }
1189    }
1190}
1191
1192impl From<CancellationDetailsReason> for StripeCancellationReason {
1193    fn from(value: CancellationDetailsReason) -> Self {
1194        match value {
1195            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1196            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1197            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1198        }
1199    }
1200}
1201
1202/// Finds or creates a billing customer using the provided customer.
1203async fn find_or_create_billing_customer(
1204    app: &Arc<AppState>,
1205    stripe_client: &stripe::Client,
1206    customer_or_id: Expandable<Customer>,
1207) -> anyhow::Result<Option<billing_customer::Model>> {
1208    let customer_id = match &customer_or_id {
1209        Expandable::Id(id) => id,
1210        Expandable::Object(customer) => customer.id.as_ref(),
1211    };
1212
1213    // If we already have a billing customer record associated with the Stripe customer,
1214    // there's nothing more we need to do.
1215    if let Some(billing_customer) = app
1216        .db
1217        .get_billing_customer_by_stripe_customer_id(customer_id)
1218        .await?
1219    {
1220        return Ok(Some(billing_customer));
1221    }
1222
1223    // If all we have is a customer ID, resolve it to a full customer record by
1224    // hitting the Stripe API.
1225    let customer = match customer_or_id {
1226        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1227        Expandable::Object(customer) => *customer,
1228    };
1229
1230    let Some(email) = customer.email else {
1231        return Ok(None);
1232    };
1233
1234    let Some(user) = app.db.get_user_by_email(&email).await? else {
1235        return Ok(None);
1236    };
1237
1238    let billing_customer = app
1239        .db
1240        .create_billing_customer(&CreateBillingCustomerParams {
1241            user_id: user.id,
1242            stripe_customer_id: customer.id.to_string(),
1243        })
1244        .await?;
1245
1246    Ok(Some(billing_customer))
1247}
1248
1249const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1250
1251pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
1252    let Some(stripe_billing) = app.stripe_billing.clone() else {
1253        log::warn!("failed to retrieve Stripe billing object");
1254        return;
1255    };
1256    let Some(llm_db) = app.llm_db.clone() else {
1257        log::warn!("failed to retrieve LLM database");
1258        return;
1259    };
1260
1261    let executor = app.executor.clone();
1262    executor.spawn_detached({
1263        let executor = executor.clone();
1264        async move {
1265            loop {
1266                sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
1267                    .await
1268                    .context("failed to sync LLM usage to Stripe")
1269                    .trace_err();
1270                executor
1271                    .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
1272                    .await;
1273            }
1274        }
1275    });
1276}
1277
1278async fn sync_token_usage_with_stripe(
1279    app: &Arc<AppState>,
1280    llm_db: &Arc<LlmDatabase>,
1281    stripe_billing: &Arc<StripeBilling>,
1282) -> anyhow::Result<()> {
1283    let events = llm_db.get_billing_events().await?;
1284    let user_ids = events
1285        .iter()
1286        .map(|(event, _)| event.user_id)
1287        .collect::<HashSet<UserId>>();
1288    let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1289
1290    for (event, model) in events {
1291        let Some((stripe_db_customer, stripe_db_subscription)) =
1292            stripe_subscriptions.get(&event.user_id)
1293        else {
1294            tracing::warn!(
1295                user_id = event.user_id.0,
1296                "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1297            );
1298            continue;
1299        };
1300        let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1301            .stripe_subscription_id
1302            .parse()
1303            .context("failed to parse stripe subscription id from db")?;
1304        let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1305            .stripe_customer_id
1306            .parse()
1307            .context("failed to parse stripe customer id from db")?;
1308
1309        let stripe_model = stripe_billing
1310            .register_model_for_token_based_usage(&model)
1311            .await?;
1312        stripe_billing
1313            .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1314            .await?;
1315        stripe_billing
1316            .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
1317            .await?;
1318        llm_db.consume_billing_event(event.id).await?;
1319    }
1320
1321    Ok(())
1322}
1323
1324const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1325
1326pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1327    let Some(stripe_billing) = app.stripe_billing.clone() else {
1328        log::warn!("failed to retrieve Stripe billing object");
1329        return;
1330    };
1331    let Some(llm_db) = app.llm_db.clone() else {
1332        log::warn!("failed to retrieve LLM database");
1333        return;
1334    };
1335
1336    let executor = app.executor.clone();
1337    executor.spawn_detached({
1338        let executor = executor.clone();
1339        async move {
1340            loop {
1341                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1342                    .await
1343                    .context("failed to sync LLM request usage to Stripe")
1344                    .trace_err();
1345                executor
1346                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1347                    .await;
1348            }
1349        }
1350    });
1351}
1352
1353async fn sync_model_request_usage_with_stripe(
1354    app: &Arc<AppState>,
1355    llm_db: &Arc<LlmDatabase>,
1356    stripe_billing: &Arc<StripeBilling>,
1357) -> anyhow::Result<()> {
1358    let usage_meters = llm_db
1359        .get_current_subscription_usage_meters(Utc::now())
1360        .await?;
1361    let user_ids = usage_meters
1362        .iter()
1363        .map(|(_, usage)| usage.user_id)
1364        .collect::<HashSet<UserId>>();
1365    let billing_subscriptions = app
1366        .db
1367        .get_active_zed_pro_billing_subscriptions(user_ids)
1368        .await?;
1369
1370    let claude_3_5_sonnet = stripe_billing
1371        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1372        .await?;
1373    let claude_3_7_sonnet = stripe_billing
1374        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1375        .await?;
1376    let claude_3_7_sonnet_max = stripe_billing
1377        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1378        .await?;
1379
1380    for (usage_meter, usage) in usage_meters {
1381        maybe!(async {
1382            let Some((billing_customer, billing_subscription)) =
1383                billing_subscriptions.get(&usage.user_id)
1384            else {
1385                bail!(
1386                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1387                    usage.user_id
1388                );
1389            };
1390
1391            let stripe_customer_id = billing_customer
1392                .stripe_customer_id
1393                .parse::<stripe::CustomerId>()
1394                .context("failed to parse Stripe customer ID from database")?;
1395            let stripe_subscription_id = billing_subscription
1396                .stripe_subscription_id
1397                .parse::<stripe::SubscriptionId>()
1398                .context("failed to parse Stripe subscription ID from database")?;
1399
1400            let model = llm_db.model_by_id(usage_meter.model_id)?;
1401
1402            let (price_id, meter_event_name) = match model.name.as_str() {
1403                "claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
1404                "claude-3-7-sonnet" => match usage_meter.mode {
1405                    CompletionMode::Normal => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
1406                    CompletionMode::Max => {
1407                        (&claude_3_7_sonnet_max.id, "claude_3_7_sonnet/requests/max")
1408                    }
1409                },
1410                model_name => {
1411                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1412                }
1413            };
1414
1415            stripe_billing
1416                .subscribe_to_price(&stripe_subscription_id, price_id)
1417                .await?;
1418            stripe_billing
1419                .bill_model_request_usage(
1420                    &stripe_customer_id,
1421                    meter_event_name,
1422                    usage_meter.requests,
1423                )
1424                .await?;
1425
1426            Ok(())
1427        })
1428        .await
1429        .log_err();
1430    }
1431
1432    Ok(())
1433}