billing.rs

   1use anyhow::{Context, anyhow, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  19    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  20    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
  21    EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
  22};
  23use util::{ResultExt, maybe};
  24
  25use crate::api::events::SnowflakeRow;
  26use crate::db::billing_subscription::{
  27    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  28};
  29use crate::llm::db::subscription_usage_meter::CompletionMode;
  30use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
  31use crate::rpc::{ResultExt as _, Server};
  32use crate::{AppState, Cents, Error, Result};
  33use crate::{db::UserId, llm::db::LlmDatabase};
  34use crate::{
  35    db::{
  36        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  37        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  38        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  39    },
  40    stripe_billing::StripeBilling,
  41};
  42
  43pub fn router() -> Router {
  44    Router::new()
  45        .route(
  46            "/billing/preferences",
  47            get(get_billing_preferences).put(update_billing_preferences),
  48        )
  49        .route(
  50            "/billing/subscriptions",
  51            get(list_billing_subscriptions).post(create_billing_subscription),
  52        )
  53        .route(
  54            "/billing/subscriptions/manage",
  55            post(manage_billing_subscription),
  56        )
  57        .route("/billing/monthly_spend", get(get_monthly_spend))
  58        .route("/billing/usage", get(get_current_usage))
  59}
  60
/// Query parameters accepted by `GET /billing/preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID of the user whose billing preferences are requested.
    github_user_id: i32,
}
  65
/// Billing preferences returned by both `GET` and `PUT /billing/preferences`.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// Monthly LLM usage spending limit, in cents. `get_billing_preferences` reports
    /// `DEFAULT_MAX_MONTHLY_SPEND` when the user has no stored preferences.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled (`false` when nothing is stored).
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents (`0` when nothing is stored).
    model_request_overages_spend_limit_in_cents: i32,
}
  72
  73async fn get_billing_preferences(
  74    Extension(app): Extension<Arc<AppState>>,
  75    Query(params): Query<GetBillingPreferencesParams>,
  76) -> Result<Json<BillingPreferencesResponse>> {
  77    let user = app
  78        .db
  79        .get_user_by_github_user_id(params.github_user_id)
  80        .await?
  81        .ok_or_else(|| anyhow!("user not found"))?;
  82
  83    let preferences = app.db.get_billing_preferences(user.id).await?;
  84
  85    Ok(Json(BillingPreferencesResponse {
  86        max_monthly_llm_usage_spending_in_cents: preferences
  87            .as_ref()
  88            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
  89                preferences.max_monthly_llm_usage_spending_in_cents
  90            }),
  91        model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
  92            preferences.model_request_overages_enabled
  93        }),
  94        model_request_overages_spend_limit_in_cents: preferences
  95            .as_ref()
  96            .map_or(0, |preferences| {
  97                preferences.model_request_overages_spend_limit_in_cents
  98            }),
  99    }))
 100}
 101
/// JSON body accepted by `PUT /billing/preferences`.
///
/// Every preference field is `#[serde(default)]`, so an omitted field deserializes as `0` or
/// `false`. `update_billing_preferences` writes all three fields unconditionally. Omitting a
/// field therefore resets the stored value rather than leaving it unchanged.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID of the user whose preferences are being updated.
    github_user_id: i32,
    /// Monthly LLM usage spending limit, in cents; negative values are clamped to `0`.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents; negative values are clamped to `0`.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
 112
 113async fn update_billing_preferences(
 114    Extension(app): Extension<Arc<AppState>>,
 115    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
 116    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 117) -> Result<Json<BillingPreferencesResponse>> {
 118    let user = app
 119        .db
 120        .get_user_by_github_user_id(body.github_user_id)
 121        .await?
 122        .ok_or_else(|| anyhow!("user not found"))?;
 123
 124    let max_monthly_llm_usage_spending_in_cents =
 125        body.max_monthly_llm_usage_spending_in_cents.max(0);
 126    let model_request_overages_spend_limit_in_cents =
 127        body.model_request_overages_spend_limit_in_cents.max(0);
 128
 129    let billing_preferences =
 130        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 131            app.db
 132                .update_billing_preferences(
 133                    user.id,
 134                    &UpdateBillingPreferencesParams {
 135                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 136                            max_monthly_llm_usage_spending_in_cents,
 137                        ),
 138                        model_request_overages_enabled: ActiveValue::set(
 139                            body.model_request_overages_enabled,
 140                        ),
 141                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 142                            model_request_overages_spend_limit_in_cents,
 143                        ),
 144                    },
 145                )
 146                .await?
 147        } else {
 148            app.db
 149                .create_billing_preferences(
 150                    user.id,
 151                    &crate::db::CreateBillingPreferencesParams {
 152                        max_monthly_llm_usage_spending_in_cents,
 153                        model_request_overages_enabled: body.model_request_overages_enabled,
 154                        model_request_overages_spend_limit_in_cents,
 155                    },
 156                )
 157                .await?
 158        };
 159
 160    SnowflakeRow::new(
 161        "Billing Preferences Updated",
 162        Some(user.metrics_id),
 163        user.admin,
 164        None,
 165        json!({
 166            "user_id": user.id,
 167            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 168            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 169            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 170        }),
 171    )
 172    .write(&app.kinesis_client, &app.config.kinesis_stream)
 173    .await
 174    .log_err();
 175
 176    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 177
 178    Ok(Json(BillingPreferencesResponse {
 179        max_monthly_llm_usage_spending_in_cents: billing_preferences
 180            .max_monthly_llm_usage_spending_in_cents,
 181        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 182        model_request_overages_spend_limit_in_cents: billing_preferences
 183            .model_request_overages_spend_limit_in_cents,
 184    }))
 185}
 186
/// Query parameters accepted by `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID of the user whose subscriptions are listed.
    github_user_id: i32,
}
 191
/// A single billing subscription as returned by `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Database ID of the billing subscription.
    id: BillingSubscriptionId,
    /// Human-readable name derived from the subscription kind (e.g. "Zed Pro").
    name: String,
    /// Stripe subscription status as stored in our database.
    status: StripeSubscriptionStatus,
    /// For Zed Pro trials only: the end of the current Stripe period, formatted as an
    /// RFC 3339 UTC timestamp with millisecond precision.
    trial_end_at: Option<String>,
    /// When the subscription is scheduled to be canceled, as an RFC 3339 UTC timestamp with
    /// millisecond precision; `None` if no cancellation is scheduled.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    ///
    /// `list_billing_subscriptions` sets this only when the status is cancelable and no
    /// cancellation is already scheduled.
    is_cancelable: bool,
}
 202
/// Response body for `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// All billing subscriptions belonging to the requested user.
    subscriptions: Vec<BillingSubscriptionJson>,
}
 207
 208async fn list_billing_subscriptions(
 209    Extension(app): Extension<Arc<AppState>>,
 210    Query(params): Query<ListBillingSubscriptionsParams>,
 211) -> Result<Json<ListBillingSubscriptionsResponse>> {
 212    let user = app
 213        .db
 214        .get_user_by_github_user_id(params.github_user_id)
 215        .await?
 216        .ok_or_else(|| anyhow!("user not found"))?;
 217
 218    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 219
 220    Ok(Json(ListBillingSubscriptionsResponse {
 221        subscriptions: subscriptions
 222            .into_iter()
 223            .map(|subscription| BillingSubscriptionJson {
 224                id: subscription.id,
 225                name: match subscription.kind {
 226                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 227                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
 228                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
 229                    None => "Zed LLM Usage".to_string(),
 230                },
 231                status: subscription.stripe_subscription_status,
 232                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
 233                    maybe!({
 234                        let end_at = subscription.stripe_current_period_end?;
 235                        let end_at = DateTime::from_timestamp(end_at, 0)?;
 236
 237                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
 238                    })
 239                } else {
 240                    None
 241                },
 242                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 243                    cancel_at
 244                        .and_utc()
 245                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 246                }),
 247                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
 248                    && subscription.stripe_cancel_at.is_none(),
 249            })
 250            .collect(),
 251    }))
 252}
 253
/// Product a user can request when starting a checkout via `POST /billing/subscriptions`.
///
/// Deserialized from snake_case strings: `"zed_pro"` or `"zed_pro_trial"`.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// A paid Zed Pro subscription.
    ZedPro,
    /// A Zed Pro trial; refused if the user's billing customer has already started a trial.
    ZedProTrial,
}
 260
/// JSON body accepted by `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID of the user starting the checkout.
    github_user_id: i32,
    /// Product to check out. When `None`, `create_billing_subscription` starts a checkout
    /// for token-based usage billing instead.
    product: Option<ProductCode>,
}
 266
/// Response body for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session the user should be sent to.
    checkout_session_url: String,
}
 271
 272/// Initiates a Stripe Checkout session for creating a billing subscription.
 273async fn create_billing_subscription(
 274    Extension(app): Extension<Arc<AppState>>,
 275    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 276) -> Result<Json<CreateBillingSubscriptionResponse>> {
 277    let user = app
 278        .db
 279        .get_user_by_github_user_id(body.github_user_id)
 280        .await?
 281        .ok_or_else(|| anyhow!("user not found"))?;
 282
 283    let Some(stripe_client) = app.stripe_client.clone() else {
 284        log::error!("failed to retrieve Stripe client");
 285        Err(Error::http(
 286            StatusCode::NOT_IMPLEMENTED,
 287            "not supported".into(),
 288        ))?
 289    };
 290    let Some(stripe_billing) = app.stripe_billing.clone() else {
 291        log::error!("failed to retrieve Stripe billing object");
 292        Err(Error::http(
 293            StatusCode::NOT_IMPLEMENTED,
 294            "not supported".into(),
 295        ))?
 296    };
 297    let Some(llm_db) = app.llm_db.clone() else {
 298        log::error!("failed to retrieve LLM database");
 299        Err(Error::http(
 300            StatusCode::NOT_IMPLEMENTED,
 301            "not supported".into(),
 302        ))?
 303    };
 304
 305    if app.db.has_active_billing_subscription(user.id).await? {
 306        return Err(Error::http(
 307            StatusCode::CONFLICT,
 308            "user already has an active subscription".into(),
 309        ));
 310    }
 311
 312    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 313    if let Some(existing_billing_customer) = &existing_billing_customer {
 314        if existing_billing_customer.has_overdue_invoices {
 315            return Err(Error::http(
 316                StatusCode::PAYMENT_REQUIRED,
 317                "user has overdue invoices".into(),
 318            ));
 319        }
 320    }
 321
 322    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 323        CustomerId::from_str(&existing_customer.stripe_customer_id)
 324            .context("failed to parse customer ID")?
 325    } else {
 326        let existing_customer = if let Some(email) = user.email_address.as_deref() {
 327            let customers = Customer::list(
 328                &stripe_client,
 329                &stripe::ListCustomers {
 330                    email: Some(email),
 331                    ..Default::default()
 332                },
 333            )
 334            .await?;
 335
 336            customers.data.first().cloned()
 337        } else {
 338            None
 339        };
 340
 341        if let Some(existing_customer) = existing_customer {
 342            existing_customer.id
 343        } else {
 344            let customer = Customer::create(
 345                &stripe_client,
 346                CreateCustomer {
 347                    email: user.email_address.as_deref(),
 348                    ..Default::default()
 349                },
 350            )
 351            .await?;
 352
 353            customer.id
 354        }
 355    };
 356
 357    let success_url = format!(
 358        "{}/account?checkout_complete=1",
 359        app.config.zed_dot_dev_url()
 360    );
 361
 362    let checkout_session_url = match body.product {
 363        Some(ProductCode::ZedPro) => {
 364            stripe_billing
 365                .checkout_with_price(
 366                    app.config.zed_pro_price_id()?,
 367                    customer_id,
 368                    &user.github_login,
 369                    &success_url,
 370                )
 371                .await?
 372        }
 373        Some(ProductCode::ZedProTrial) => {
 374            if let Some(existing_billing_customer) = &existing_billing_customer {
 375                if existing_billing_customer.trial_started_at.is_some() {
 376                    return Err(Error::http(
 377                        StatusCode::FORBIDDEN,
 378                        "user already used free trial".into(),
 379                    ));
 380                }
 381            }
 382
 383            let feature_flags = app.db.get_user_flags(user.id).await?;
 384
 385            stripe_billing
 386                .checkout_with_zed_pro_trial(
 387                    app.config.zed_pro_price_id()?,
 388                    customer_id,
 389                    &user.github_login,
 390                    feature_flags,
 391                    &success_url,
 392                )
 393                .await?
 394        }
 395        None => {
 396            let default_model = llm_db.model(
 397                zed_llm_client::LanguageModelProvider::Anthropic,
 398                "claude-3-7-sonnet",
 399            )?;
 400            let stripe_model = stripe_billing
 401                .register_model_for_token_based_usage(default_model)
 402                .await?;
 403            stripe_billing
 404                .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
 405                .await?
 406        }
 407    };
 408
 409    Ok(Json(CreateBillingSubscriptionResponse {
 410        checkout_session_url,
 411    }))
 412}
 413
/// What the caller wants to do with a subscription in `POST /billing/subscriptions/manage`.
///
/// Deserialized from snake_case strings (e.g. `"upgrade_to_pro"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to upgrade to Zed Pro.
    ///
    /// A user on a Zed Pro trial has the trial ended immediately. Otherwise the user is sent
    /// through the billing portal's subscription-update confirmation flow.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    ///
    /// Opens the billing portal's subscription-cancel flow.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    ///
    /// Handled directly through the Stripe API; no billing portal session is created.
    StopCancellation,
}
 428
/// JSON body accepted by `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID of the user making the request.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
}
 436
/// Response body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session to send the user to. This is `None` when the
    /// request was completed without the portal: stopping a cancellation, or ending a Zed Pro
    /// trial early.
    billing_portal_session_url: Option<String>,
}
 441
 442/// Initiates a Stripe customer portal session for managing a billing subscription.
 443async fn manage_billing_subscription(
 444    Extension(app): Extension<Arc<AppState>>,
 445    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 446) -> Result<Json<ManageBillingSubscriptionResponse>> {
 447    let user = app
 448        .db
 449        .get_user_by_github_user_id(body.github_user_id)
 450        .await?
 451        .ok_or_else(|| anyhow!("user not found"))?;
 452
 453    let Some(stripe_client) = app.stripe_client.clone() else {
 454        log::error!("failed to retrieve Stripe client");
 455        Err(Error::http(
 456            StatusCode::NOT_IMPLEMENTED,
 457            "not supported".into(),
 458        ))?
 459    };
 460
 461    let customer = app
 462        .db
 463        .get_billing_customer_by_user_id(user.id)
 464        .await?
 465        .ok_or_else(|| anyhow!("billing customer not found"))?;
 466    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 467        .context("failed to parse customer ID")?;
 468
 469    let subscription = app
 470        .db
 471        .get_billing_subscription_by_id(body.subscription_id)
 472        .await?
 473        .ok_or_else(|| anyhow!("subscription not found"))?;
 474    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 475        .context("failed to parse subscription ID")?;
 476
 477    if body.intent == ManageSubscriptionIntent::StopCancellation {
 478        let updated_stripe_subscription = Subscription::update(
 479            &stripe_client,
 480            &subscription_id,
 481            stripe::UpdateSubscription {
 482                cancel_at_period_end: Some(false),
 483                ..Default::default()
 484            },
 485        )
 486        .await?;
 487
 488        app.db
 489            .update_billing_subscription(
 490                subscription.id,
 491                &UpdateBillingSubscriptionParams {
 492                    stripe_cancel_at: ActiveValue::set(
 493                        updated_stripe_subscription
 494                            .cancel_at
 495                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 496                            .map(|time| time.naive_utc()),
 497                    ),
 498                    ..Default::default()
 499                },
 500            )
 501            .await?;
 502
 503        return Ok(Json(ManageBillingSubscriptionResponse {
 504            billing_portal_session_url: None,
 505        }));
 506    }
 507
 508    let flow = match body.intent {
 509        ManageSubscriptionIntent::ManageSubscription => None,
 510        ManageSubscriptionIntent::UpgradeToPro => {
 511            let zed_pro_price_id = app.config.zed_pro_price_id()?;
 512            let zed_free_price_id = app.config.zed_free_price_id()?;
 513
 514            let stripe_subscription =
 515                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 516
 517            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 518                && stripe_subscription.items.data.iter().any(|item| {
 519                    item.price
 520                        .as_ref()
 521                        .map_or(false, |price| price.id == zed_pro_price_id)
 522                });
 523            if is_on_zed_pro_trial {
 524                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 525                Subscription::update(
 526                    &stripe_client,
 527                    &stripe_subscription.id,
 528                    stripe::UpdateSubscription {
 529                        trial_end: Some(stripe::Scheduled::now()),
 530                        ..Default::default()
 531                    },
 532                )
 533                .await?;
 534
 535                return Ok(Json(ManageBillingSubscriptionResponse {
 536                    billing_portal_session_url: None,
 537                }));
 538            }
 539
 540            let subscription_item_to_update = stripe_subscription
 541                .items
 542                .data
 543                .iter()
 544                .find_map(|item| {
 545                    let price = item.price.as_ref()?;
 546
 547                    if price.id == zed_free_price_id {
 548                        Some(item.id.clone())
 549                    } else {
 550                        None
 551                    }
 552                })
 553                .ok_or_else(|| anyhow!("No subscription item to update"))?;
 554
 555            Some(CreateBillingPortalSessionFlowData {
 556                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 557                subscription_update_confirm: Some(
 558                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 559                        subscription: subscription.stripe_subscription_id,
 560                        items: vec![
 561                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 562                                id: subscription_item_to_update.to_string(),
 563                                price: Some(zed_pro_price_id.to_string()),
 564                                quantity: Some(1),
 565                            },
 566                        ],
 567                        discounts: None,
 568                    },
 569                ),
 570                ..Default::default()
 571            })
 572        }
 573        ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
 574            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 575            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 576                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 577                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 578                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 579                }),
 580                ..Default::default()
 581            }),
 582            subscription_cancel: Some(
 583                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 584                    subscription: subscription.stripe_subscription_id,
 585                    retention: None,
 586                },
 587            ),
 588            ..Default::default()
 589        }),
 590        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 591    };
 592
 593    let mut params = CreateBillingPortalSession::new(customer_id);
 594    params.flow_data = flow;
 595    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 596    params.return_url = Some(&return_url);
 597
 598    let session = BillingPortalSession::create(&stripe_client, params).await?;
 599
 600    Ok(Json(ManageBillingSubscriptionResponse {
 601        billing_portal_session_url: Some(session.url),
 602    }))
 603}
 604
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
///   1. Being short enough that we update quickly when something in Stripe changes.
///   2. Being long enough that we don't eat into our rate limits.
///
/// We currently poll every 5 seconds. As a point of reference, the Sequin folks say they
/// have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 617
/// The maximum number of events to return per page.
///
/// We set this to 100, the maximum Stripe accepts, so we have to make fewer requests to
/// Stripe. From the Stripe API reference:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 624
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed. The pages need not be consecutive: `poll_stripe_events`
/// never resets its counter within a poll. An empty page also counts as fully processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 631
 632/// Polls the Stripe events API periodically to reconcile the records in our
 633/// database with the data in Stripe.
 634pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 635    let Some(stripe_client) = app.stripe_client.clone() else {
 636        log::warn!("failed to retrieve Stripe client");
 637        return;
 638    };
 639
 640    let executor = app.executor.clone();
 641    executor.spawn_detached({
 642        let executor = executor.clone();
 643        async move {
 644            loop {
 645                poll_stripe_events(&app, &rpc_server, &stripe_client)
 646                    .await
 647                    .log_err();
 648
 649                executor.sleep(POLL_EVENTS_INTERVAL).await;
 650            }
 651        }
 652    });
 653}
 654
 655async fn poll_stripe_events(
 656    app: &Arc<AppState>,
 657    rpc_server: &Arc<Server>,
 658    stripe_client: &stripe::Client,
 659) -> anyhow::Result<()> {
 660    fn event_type_to_string(event_type: EventType) -> String {
 661        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 662        // so we need to unquote it.
 663        event_type.to_string().trim_matches('"').to_string()
 664    }
 665
 666    let event_types = [
 667        EventType::CustomerCreated,
 668        EventType::CustomerUpdated,
 669        EventType::CustomerSubscriptionCreated,
 670        EventType::CustomerSubscriptionUpdated,
 671        EventType::CustomerSubscriptionPaused,
 672        EventType::CustomerSubscriptionResumed,
 673        EventType::CustomerSubscriptionDeleted,
 674    ]
 675    .into_iter()
 676    .map(event_type_to_string)
 677    .collect::<Vec<_>>();
 678
 679    let mut pages_of_already_processed_events = 0;
 680    let mut unprocessed_events = Vec::new();
 681
 682    log::info!(
 683        "Stripe events: starting retrieval for {}",
 684        event_types.join(", ")
 685    );
 686    let mut params = ListEvents::new();
 687    params.types = Some(event_types.clone());
 688    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 689
 690    let mut event_pages = stripe::Event::list(&stripe_client, &params)
 691        .await?
 692        .paginate(params);
 693
 694    loop {
 695        let processed_event_ids = {
 696            let event_ids = event_pages
 697                .page
 698                .data
 699                .iter()
 700                .map(|event| event.id.as_str())
 701                .collect::<Vec<_>>();
 702            app.db
 703                .get_processed_stripe_events_by_event_ids(&event_ids)
 704                .await?
 705                .into_iter()
 706                .map(|event| event.stripe_event_id)
 707                .collect::<Vec<_>>()
 708        };
 709
 710        let mut processed_events_in_page = 0;
 711        let events_in_page = event_pages.page.data.len();
 712        for event in &event_pages.page.data {
 713            if processed_event_ids.contains(&event.id.to_string()) {
 714                processed_events_in_page += 1;
 715                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 716            } else {
 717                unprocessed_events.push(event.clone());
 718            }
 719        }
 720
 721        if processed_events_in_page == events_in_page {
 722            pages_of_already_processed_events += 1;
 723        }
 724
 725        if event_pages.page.has_more {
 726            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 727            {
 728                log::info!(
 729                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 730                );
 731                break;
 732            } else {
 733                log::info!("Stripe events: retrieving next page");
 734                event_pages = event_pages.next(&stripe_client).await?;
 735            }
 736        } else {
 737            break;
 738        }
 739    }
 740
 741    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 742
 743    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 744    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 745
 746    for event in unprocessed_events {
 747        let event_id = event.id.clone();
 748        let processed_event_params = CreateProcessedStripeEventParams {
 749            stripe_event_id: event.id.to_string(),
 750            stripe_event_type: event_type_to_string(event.type_),
 751            stripe_event_created_timestamp: event.created,
 752        };
 753
 754        // If the event has happened too far in the past, we don't want to
 755        // process it and risk overwriting other more-recent updates.
 756        //
 757        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 758        let one_day = Duration::from_secs(24 * 60 * 60);
 759        let a_day_ago = Utc::now() - one_day;
 760        if a_day_ago.timestamp() > event.created {
 761            log::info!(
 762                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 763                event_id
 764            );
 765            app.db
 766                .create_processed_stripe_event(&processed_event_params)
 767                .await?;
 768
 769            return Ok(());
 770        }
 771
 772        let process_result = match event.type_ {
 773            EventType::CustomerCreated | EventType::CustomerUpdated => {
 774                handle_customer_event(app, stripe_client, event).await
 775            }
 776            EventType::CustomerSubscriptionCreated
 777            | EventType::CustomerSubscriptionUpdated
 778            | EventType::CustomerSubscriptionPaused
 779            | EventType::CustomerSubscriptionResumed
 780            | EventType::CustomerSubscriptionDeleted => {
 781                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 782            }
 783            _ => Ok(()),
 784        };
 785
 786        if let Some(()) = process_result
 787            .with_context(|| format!("failed to process event {event_id} successfully"))
 788            .log_err()
 789        {
 790            app.db
 791                .create_processed_stripe_event(&processed_event_params)
 792                .await?;
 793        }
 794    }
 795
 796    Ok(())
 797}
 798
 799async fn handle_customer_event(
 800    app: &Arc<AppState>,
 801    _stripe_client: &stripe::Client,
 802    event: stripe::Event,
 803) -> anyhow::Result<()> {
 804    let EventObject::Customer(customer) = event.data.object else {
 805        bail!("unexpected event payload for {}", event.id);
 806    };
 807
 808    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 809
 810    let Some(email) = customer.email else {
 811        log::info!("Stripe customer has no email: skipping");
 812        return Ok(());
 813    };
 814
 815    let Some(user) = app.db.get_user_by_email(&email).await? else {
 816        log::info!("no user found for email: skipping");
 817        return Ok(());
 818    };
 819
 820    if let Some(existing_customer) = app
 821        .db
 822        .get_billing_customer_by_stripe_customer_id(&customer.id)
 823        .await?
 824    {
 825        app.db
 826            .update_billing_customer(
 827                existing_customer.id,
 828                &UpdateBillingCustomerParams {
 829                    // For now we just leave the information as-is, as it is not
 830                    // likely to change.
 831                    ..Default::default()
 832                },
 833            )
 834            .await?;
 835    } else {
 836        app.db
 837            .create_billing_customer(&CreateBillingCustomerParams {
 838                user_id: user.id,
 839                stripe_customer_id: customer.id.to_string(),
 840            })
 841            .await?;
 842    }
 843
 844    Ok(())
 845}
 846
 847async fn handle_customer_subscription_event(
 848    app: &Arc<AppState>,
 849    rpc_server: &Arc<Server>,
 850    stripe_client: &stripe::Client,
 851    event: stripe::Event,
 852) -> anyhow::Result<()> {
 853    let EventObject::Subscription(subscription) = event.data.object else {
 854        bail!("unexpected event payload for {}", event.id);
 855    };
 856
 857    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 858
 859    let subscription_kind = maybe!({
 860        let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
 861        let zed_free_price_id = app.config.zed_free_price_id().ok()?;
 862
 863        subscription.items.data.iter().find_map(|item| {
 864            let price = item.price.as_ref()?;
 865
 866            if price.id == zed_pro_price_id {
 867                Some(if subscription.status == SubscriptionStatus::Trialing {
 868                    SubscriptionKind::ZedProTrial
 869                } else {
 870                    SubscriptionKind::ZedPro
 871                })
 872            } else if price.id == zed_free_price_id {
 873                Some(SubscriptionKind::ZedFree)
 874            } else {
 875                None
 876            }
 877        })
 878    });
 879
 880    let billing_customer =
 881        find_or_create_billing_customer(app, stripe_client, subscription.customer)
 882            .await?
 883            .ok_or_else(|| anyhow!("billing customer not found"))?;
 884
 885    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
 886        if subscription.status == SubscriptionStatus::Trialing {
 887            let current_period_start =
 888                DateTime::from_timestamp(subscription.current_period_start, 0)
 889                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;
 890
 891            app.db
 892                .update_billing_customer(
 893                    billing_customer.id,
 894                    &UpdateBillingCustomerParams {
 895                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
 896                        ..Default::default()
 897                    },
 898                )
 899                .await?;
 900        }
 901    }
 902
 903    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
 904        && subscription
 905            .cancellation_details
 906            .as_ref()
 907            .and_then(|details| details.reason)
 908            .map_or(false, |reason| {
 909                reason == CancellationDetailsReason::PaymentFailed
 910            });
 911
 912    if was_canceled_due_to_payment_failure {
 913        app.db
 914            .update_billing_customer(
 915                billing_customer.id,
 916                &UpdateBillingCustomerParams {
 917                    has_overdue_invoices: ActiveValue::set(true),
 918                    ..Default::default()
 919                },
 920            )
 921            .await?;
 922    }
 923
 924    if let Some(existing_subscription) = app
 925        .db
 926        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
 927        .await?
 928    {
 929        let llm_db = app
 930            .llm_db
 931            .clone()
 932            .ok_or_else(|| anyhow!("LLM DB not initialized"))?;
 933
 934        let new_period_start_at =
 935            chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
 936                .ok_or_else(|| anyhow!("No subscription period start"))?;
 937        let new_period_end_at =
 938            chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
 939                .ok_or_else(|| anyhow!("No subscription period end"))?;
 940
 941        llm_db
 942            .transfer_existing_subscription_usage(
 943                billing_customer.user_id,
 944                &existing_subscription,
 945                subscription_kind,
 946                new_period_start_at,
 947                new_period_end_at,
 948            )
 949            .await?;
 950
 951        app.db
 952            .update_billing_subscription(
 953                existing_subscription.id,
 954                &UpdateBillingSubscriptionParams {
 955                    billing_customer_id: ActiveValue::set(billing_customer.id),
 956                    kind: ActiveValue::set(subscription_kind),
 957                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
 958                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
 959                    stripe_cancel_at: ActiveValue::set(
 960                        subscription
 961                            .cancel_at
 962                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 963                            .map(|time| time.naive_utc()),
 964                    ),
 965                    stripe_cancellation_reason: ActiveValue::set(
 966                        subscription
 967                            .cancellation_details
 968                            .and_then(|details| details.reason)
 969                            .map(|reason| reason.into()),
 970                    ),
 971                    stripe_current_period_start: ActiveValue::set(Some(
 972                        subscription.current_period_start,
 973                    )),
 974                    stripe_current_period_end: ActiveValue::set(Some(
 975                        subscription.current_period_end,
 976                    )),
 977                },
 978            )
 979            .await?;
 980    } else {
 981        // If the user already has an active billing subscription, ignore the
 982        // event and return an `Ok` to signal that it was processed
 983        // successfully.
 984        //
 985        // There is the possibility that this could cause us to not create a
 986        // subscription in the following scenario:
 987        //
 988        //   1. User has an active subscription A
 989        //   2. User cancels subscription A
 990        //   3. User creates a new subscription B
 991        //   4. We process the new subscription B before the cancellation of subscription A
 992        //   5. User ends up with no subscriptions
 993        //
 994        // In theory this situation shouldn't arise as we try to process the events in the order they occur.
 995        if app
 996            .db
 997            .has_active_billing_subscription(billing_customer.user_id)
 998            .await?
 999        {
1000            log::info!(
1001                "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1002                user_id = billing_customer.user_id,
1003                subscription_id = subscription.id
1004            );
1005            return Ok(());
1006        }
1007
1008        app.db
1009            .create_billing_subscription(&CreateBillingSubscriptionParams {
1010                billing_customer_id: billing_customer.id,
1011                kind: subscription_kind,
1012                stripe_subscription_id: subscription.id.to_string(),
1013                stripe_subscription_status: subscription.status.into(),
1014                stripe_cancellation_reason: subscription
1015                    .cancellation_details
1016                    .and_then(|details| details.reason)
1017                    .map(|reason| reason.into()),
1018                stripe_current_period_start: Some(subscription.current_period_start),
1019                stripe_current_period_end: Some(subscription.current_period_end),
1020            })
1021            .await?;
1022    }
1023
1024    // When the user's subscription changes, we want to refresh their LLM tokens
1025    // to either grant/revoke access.
1026    rpc_server
1027        .refresh_llm_tokens_for_user(billing_customer.user_id)
1028        .await;
1029
1030    Ok(())
1031}
1032
1033#[derive(Debug, Deserialize)]
1034struct GetMonthlySpendParams {
1035    github_user_id: i32,
1036}
1037
1038#[derive(Debug, Serialize)]
1039struct GetMonthlySpendResponse {
1040    monthly_free_tier_spend_in_cents: u32,
1041    monthly_free_tier_allowance_in_cents: u32,
1042    monthly_spend_in_cents: u32,
1043}
1044
1045async fn get_monthly_spend(
1046    Extension(app): Extension<Arc<AppState>>,
1047    Query(params): Query<GetMonthlySpendParams>,
1048) -> Result<Json<GetMonthlySpendResponse>> {
1049    let user = app
1050        .db
1051        .get_user_by_github_user_id(params.github_user_id)
1052        .await?
1053        .ok_or_else(|| anyhow!("user not found"))?;
1054
1055    let Some(llm_db) = app.llm_db.clone() else {
1056        return Err(Error::http(
1057            StatusCode::NOT_IMPLEMENTED,
1058            "LLM database not available".into(),
1059        ));
1060    };
1061
1062    let free_tier = user
1063        .custom_llm_monthly_allowance_in_cents
1064        .map(|allowance| Cents(allowance as u32))
1065        .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1066
1067    let spending_for_month = llm_db
1068        .get_user_spending_for_month(user.id, Utc::now())
1069        .await?;
1070
1071    let free_tier_spend = Cents::min(spending_for_month, free_tier);
1072    let monthly_spend = spending_for_month.saturating_sub(free_tier);
1073
1074    Ok(Json(GetMonthlySpendResponse {
1075        monthly_free_tier_spend_in_cents: free_tier_spend.0,
1076        monthly_free_tier_allowance_in_cents: free_tier.0,
1077        monthly_spend_in_cents: monthly_spend.0,
1078    }))
1079}
1080
1081#[derive(Debug, Deserialize)]
1082struct GetCurrentUsageParams {
1083    github_user_id: i32,
1084}
1085
1086#[derive(Debug, Serialize)]
1087struct UsageCounts {
1088    pub used: i32,
1089    pub limit: Option<i32>,
1090    pub remaining: Option<i32>,
1091}
1092
1093#[derive(Debug, Serialize)]
1094struct ModelRequestUsage {
1095    pub model: String,
1096    pub mode: CompletionMode,
1097    pub requests: i32,
1098}
1099
1100#[derive(Debug, Serialize)]
1101struct GetCurrentUsageResponse {
1102    pub model_requests: UsageCounts,
1103    pub model_request_usage: Vec<ModelRequestUsage>,
1104    pub edit_predictions: UsageCounts,
1105}
1106
1107async fn get_current_usage(
1108    Extension(app): Extension<Arc<AppState>>,
1109    Query(params): Query<GetCurrentUsageParams>,
1110) -> Result<Json<GetCurrentUsageResponse>> {
1111    let user = app
1112        .db
1113        .get_user_by_github_user_id(params.github_user_id)
1114        .await?
1115        .ok_or_else(|| anyhow!("user not found"))?;
1116
1117    let Some(llm_db) = app.llm_db.clone() else {
1118        return Err(Error::http(
1119            StatusCode::NOT_IMPLEMENTED,
1120            "LLM database not available".into(),
1121        ));
1122    };
1123
1124    let empty_usage = GetCurrentUsageResponse {
1125        model_requests: UsageCounts {
1126            used: 0,
1127            limit: Some(0),
1128            remaining: Some(0),
1129        },
1130        model_request_usage: Vec::new(),
1131        edit_predictions: UsageCounts {
1132            used: 0,
1133            limit: Some(0),
1134            remaining: Some(0),
1135        },
1136    };
1137
1138    let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1139        return Ok(Json(empty_usage));
1140    };
1141
1142    let subscription_period = maybe!({
1143        let period_start_at = subscription.current_period_start_at()?;
1144        let period_end_at = subscription.current_period_end_at()?;
1145
1146        Some((period_start_at, period_end_at))
1147    });
1148
1149    let Some((period_start_at, period_end_at)) = subscription_period else {
1150        return Ok(Json(empty_usage));
1151    };
1152
1153    let usage = llm_db
1154        .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1155        .await?;
1156    let Some(usage) = usage else {
1157        return Ok(Json(empty_usage));
1158    };
1159
1160    let plan = match usage.plan {
1161        SubscriptionKind::ZedPro => zed_llm_client::Plan::ZedPro,
1162        SubscriptionKind::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
1163        SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
1164    };
1165
1166    let model_requests_limit = match plan.model_requests_limit() {
1167        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1168        zed_llm_client::UsageLimit::Unlimited => None,
1169    };
1170    let edit_prediction_limit = match plan.edit_predictions_limit() {
1171        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1172        zed_llm_client::UsageLimit::Unlimited => None,
1173    };
1174
1175    let subscription_usage_meters = llm_db
1176        .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1177        .await?;
1178
1179    let model_request_usage = subscription_usage_meters
1180        .into_iter()
1181        .filter_map(|(usage_meter, _usage)| {
1182            let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1183
1184            Some(ModelRequestUsage {
1185                model: model.name.clone(),
1186                mode: usage_meter.mode,
1187                requests: usage_meter.requests,
1188            })
1189        })
1190        .collect::<Vec<_>>();
1191
1192    Ok(Json(GetCurrentUsageResponse {
1193        model_requests: UsageCounts {
1194            used: usage.model_requests,
1195            limit: model_requests_limit,
1196            remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1197        },
1198        model_request_usage,
1199        edit_predictions: UsageCounts {
1200            used: usage.edit_predictions,
1201            limit: edit_prediction_limit,
1202            remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
1203        },
1204    }))
1205}
1206
1207impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1208    fn from(value: SubscriptionStatus) -> Self {
1209        match value {
1210            SubscriptionStatus::Incomplete => Self::Incomplete,
1211            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1212            SubscriptionStatus::Trialing => Self::Trialing,
1213            SubscriptionStatus::Active => Self::Active,
1214            SubscriptionStatus::PastDue => Self::PastDue,
1215            SubscriptionStatus::Canceled => Self::Canceled,
1216            SubscriptionStatus::Unpaid => Self::Unpaid,
1217            SubscriptionStatus::Paused => Self::Paused,
1218        }
1219    }
1220}
1221
1222impl From<CancellationDetailsReason> for StripeCancellationReason {
1223    fn from(value: CancellationDetailsReason) -> Self {
1224        match value {
1225            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1226            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1227            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1228        }
1229    }
1230}
1231
1232/// Finds or creates a billing customer using the provided customer.
1233async fn find_or_create_billing_customer(
1234    app: &Arc<AppState>,
1235    stripe_client: &stripe::Client,
1236    customer_or_id: Expandable<Customer>,
1237) -> anyhow::Result<Option<billing_customer::Model>> {
1238    let customer_id = match &customer_or_id {
1239        Expandable::Id(id) => id,
1240        Expandable::Object(customer) => customer.id.as_ref(),
1241    };
1242
1243    // If we already have a billing customer record associated with the Stripe customer,
1244    // there's nothing more we need to do.
1245    if let Some(billing_customer) = app
1246        .db
1247        .get_billing_customer_by_stripe_customer_id(customer_id)
1248        .await?
1249    {
1250        return Ok(Some(billing_customer));
1251    }
1252
1253    // If all we have is a customer ID, resolve it to a full customer record by
1254    // hitting the Stripe API.
1255    let customer = match customer_or_id {
1256        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1257        Expandable::Object(customer) => *customer,
1258    };
1259
1260    let Some(email) = customer.email else {
1261        return Ok(None);
1262    };
1263
1264    let Some(user) = app.db.get_user_by_email(&email).await? else {
1265        return Ok(None);
1266    };
1267
1268    let billing_customer = app
1269        .db
1270        .create_billing_customer(&CreateBillingCustomerParams {
1271            user_id: user.id,
1272            stripe_customer_id: customer.id.to_string(),
1273        })
1274        .await?;
1275
1276    Ok(Some(billing_customer))
1277}
1278
1279const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1280
1281pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
1282    let Some(stripe_billing) = app.stripe_billing.clone() else {
1283        log::warn!("failed to retrieve Stripe billing object");
1284        return;
1285    };
1286    let Some(llm_db) = app.llm_db.clone() else {
1287        log::warn!("failed to retrieve LLM database");
1288        return;
1289    };
1290
1291    let executor = app.executor.clone();
1292    executor.spawn_detached({
1293        let executor = executor.clone();
1294        async move {
1295            loop {
1296                sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
1297                    .await
1298                    .context("failed to sync LLM usage to Stripe")
1299                    .trace_err();
1300                executor
1301                    .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
1302                    .await;
1303            }
1304        }
1305    });
1306}
1307
1308async fn sync_token_usage_with_stripe(
1309    app: &Arc<AppState>,
1310    llm_db: &Arc<LlmDatabase>,
1311    stripe_billing: &Arc<StripeBilling>,
1312) -> anyhow::Result<()> {
1313    let events = llm_db.get_billing_events().await?;
1314    let user_ids = events
1315        .iter()
1316        .map(|(event, _)| event.user_id)
1317        .collect::<HashSet<UserId>>();
1318    let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
1319
1320    for (event, model) in events {
1321        let Some((stripe_db_customer, stripe_db_subscription)) =
1322            stripe_subscriptions.get(&event.user_id)
1323        else {
1324            tracing::warn!(
1325                user_id = event.user_id.0,
1326                "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
1327            );
1328            continue;
1329        };
1330        let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
1331            .stripe_subscription_id
1332            .parse()
1333            .context("failed to parse stripe subscription id from db")?;
1334        let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1335            .stripe_customer_id
1336            .parse()
1337            .context("failed to parse stripe customer id from db")?;
1338
1339        let stripe_model = stripe_billing
1340            .register_model_for_token_based_usage(&model)
1341            .await?;
1342        stripe_billing
1343            .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1344            .await?;
1345        stripe_billing
1346            .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
1347            .await?;
1348        llm_db.consume_billing_event(event.id).await?;
1349    }
1350
1351    Ok(())
1352}
1353
1354const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1355
1356pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1357    let Some(stripe_billing) = app.stripe_billing.clone() else {
1358        log::warn!("failed to retrieve Stripe billing object");
1359        return;
1360    };
1361    let Some(llm_db) = app.llm_db.clone() else {
1362        log::warn!("failed to retrieve LLM database");
1363        return;
1364    };
1365
1366    let executor = app.executor.clone();
1367    executor.spawn_detached({
1368        let executor = executor.clone();
1369        async move {
1370            loop {
1371                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1372                    .await
1373                    .context("failed to sync LLM request usage to Stripe")
1374                    .trace_err();
1375                executor
1376                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1377                    .await;
1378            }
1379        }
1380    });
1381}
1382
1383async fn sync_model_request_usage_with_stripe(
1384    app: &Arc<AppState>,
1385    llm_db: &Arc<LlmDatabase>,
1386    stripe_billing: &Arc<StripeBilling>,
1387) -> anyhow::Result<()> {
1388    let usage_meters = llm_db
1389        .get_current_subscription_usage_meters(Utc::now())
1390        .await?;
1391    let user_ids = usage_meters
1392        .iter()
1393        .map(|(_, usage)| usage.user_id)
1394        .collect::<HashSet<UserId>>();
1395    let billing_subscriptions = app
1396        .db
1397        .get_active_zed_pro_billing_subscriptions(user_ids)
1398        .await?;
1399
1400    let claude_3_5_sonnet = stripe_billing
1401        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1402        .await?;
1403    let claude_3_7_sonnet = stripe_billing
1404        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1405        .await?;
1406    let claude_3_7_sonnet_max = stripe_billing
1407        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1408        .await?;
1409
1410    for (usage_meter, usage) in usage_meters {
1411        maybe!(async {
1412            let Some((billing_customer, billing_subscription)) =
1413                billing_subscriptions.get(&usage.user_id)
1414            else {
1415                bail!(
1416                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1417                    usage.user_id
1418                );
1419            };
1420
1421            let stripe_customer_id = billing_customer
1422                .stripe_customer_id
1423                .parse::<stripe::CustomerId>()
1424                .context("failed to parse Stripe customer ID from database")?;
1425            let stripe_subscription_id = billing_subscription
1426                .stripe_subscription_id
1427                .parse::<stripe::SubscriptionId>()
1428                .context("failed to parse Stripe subscription ID from database")?;
1429
1430            let model = llm_db.model_by_id(usage_meter.model_id)?;
1431
1432            let (price_id, meter_event_name) = match model.name.as_str() {
1433                "claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
1434                "claude-3-7-sonnet" => match usage_meter.mode {
1435                    CompletionMode::Normal => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
1436                    CompletionMode::Max => {
1437                        (&claude_3_7_sonnet_max.id, "claude_3_7_sonnet/requests/max")
1438                    }
1439                },
1440                model_name => {
1441                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1442                }
1443            };
1444
1445            stripe_billing
1446                .subscribe_to_price(&stripe_subscription_id, price_id)
1447                .await?;
1448            stripe_billing
1449                .bill_model_request_usage(
1450                    &stripe_customer_id,
1451                    meter_event_name,
1452                    usage_meter.requests,
1453                )
1454                .await?;
1455
1456            Ok(())
1457        })
1458        .await
1459        .log_err();
1460    }
1461
1462    Ok(())
1463}