billing.rs

   1use anyhow::{Context as _, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  19    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  20    CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
  21    PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
  22};
  23use util::{ResultExt, maybe};
  24
  25use crate::api::events::SnowflakeRow;
  26use crate::db::billing_subscription::{
  27    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  28};
  29use crate::llm::db::subscription_usage_meter::CompletionMode;
  30use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
  31use crate::rpc::{ResultExt as _, Server};
  32use crate::stripe_client::{
  33    StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
  34    StripeSubscriptionId, UpdateCustomerParams,
  35};
  36use crate::{AppState, Error, Result};
  37use crate::{db::UserId, llm::db::LlmDatabase};
  38use crate::{
  39    db::{
  40        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  41        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  42        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  43    },
  44    stripe_billing::StripeBilling,
  45};
  46
  47pub fn router() -> Router {
  48    Router::new()
  49        .route(
  50            "/billing/preferences",
  51            get(get_billing_preferences).put(update_billing_preferences),
  52        )
  53        .route(
  54            "/billing/subscriptions",
  55            get(list_billing_subscriptions).post(create_billing_subscription),
  56        )
  57        .route(
  58            "/billing/subscriptions/manage",
  59            post(manage_billing_subscription),
  60        )
  61        .route(
  62            "/billing/subscriptions/sync",
  63            post(sync_billing_subscription),
  64        )
  65        .route("/billing/usage", get(get_current_usage))
  66}
  67
  68#[derive(Debug, Deserialize)]
  69struct GetBillingPreferencesParams {
  70    github_user_id: i32,
  71}
  72
  73#[derive(Debug, Serialize)]
  74struct BillingPreferencesResponse {
  75    trial_started_at: Option<String>,
  76    max_monthly_llm_usage_spending_in_cents: i32,
  77    model_request_overages_enabled: bool,
  78    model_request_overages_spend_limit_in_cents: i32,
  79}
  80
  81async fn get_billing_preferences(
  82    Extension(app): Extension<Arc<AppState>>,
  83    Query(params): Query<GetBillingPreferencesParams>,
  84) -> Result<Json<BillingPreferencesResponse>> {
  85    let user = app
  86        .db
  87        .get_user_by_github_user_id(params.github_user_id)
  88        .await?
  89        .context("user not found")?;
  90
  91    let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
  92    let preferences = app.db.get_billing_preferences(user.id).await?;
  93
  94    Ok(Json(BillingPreferencesResponse {
  95        trial_started_at: billing_customer
  96            .and_then(|billing_customer| billing_customer.trial_started_at)
  97            .map(|trial_started_at| {
  98                trial_started_at
  99                    .and_utc()
 100                    .to_rfc3339_opts(SecondsFormat::Millis, true)
 101            }),
 102        max_monthly_llm_usage_spending_in_cents: preferences
 103            .as_ref()
 104            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
 105                preferences.max_monthly_llm_usage_spending_in_cents
 106            }),
 107        model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
 108            preferences.model_request_overages_enabled
 109        }),
 110        model_request_overages_spend_limit_in_cents: preferences
 111            .as_ref()
 112            .map_or(0, |preferences| {
 113                preferences.model_request_overages_spend_limit_in_cents
 114            }),
 115    }))
 116}
 117
 118#[derive(Debug, Deserialize)]
 119struct UpdateBillingPreferencesBody {
 120    github_user_id: i32,
 121    #[serde(default)]
 122    max_monthly_llm_usage_spending_in_cents: i32,
 123    #[serde(default)]
 124    model_request_overages_enabled: bool,
 125    #[serde(default)]
 126    model_request_overages_spend_limit_in_cents: i32,
 127}
 128
 129async fn update_billing_preferences(
 130    Extension(app): Extension<Arc<AppState>>,
 131    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
 132    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 133) -> Result<Json<BillingPreferencesResponse>> {
 134    let user = app
 135        .db
 136        .get_user_by_github_user_id(body.github_user_id)
 137        .await?
 138        .context("user not found")?;
 139
 140    let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 141
 142    let max_monthly_llm_usage_spending_in_cents =
 143        body.max_monthly_llm_usage_spending_in_cents.max(0);
 144    let model_request_overages_spend_limit_in_cents =
 145        body.model_request_overages_spend_limit_in_cents.max(0);
 146
 147    let billing_preferences =
 148        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 149            app.db
 150                .update_billing_preferences(
 151                    user.id,
 152                    &UpdateBillingPreferencesParams {
 153                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 154                            max_monthly_llm_usage_spending_in_cents,
 155                        ),
 156                        model_request_overages_enabled: ActiveValue::set(
 157                            body.model_request_overages_enabled,
 158                        ),
 159                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 160                            model_request_overages_spend_limit_in_cents,
 161                        ),
 162                    },
 163                )
 164                .await?
 165        } else {
 166            app.db
 167                .create_billing_preferences(
 168                    user.id,
 169                    &crate::db::CreateBillingPreferencesParams {
 170                        max_monthly_llm_usage_spending_in_cents,
 171                        model_request_overages_enabled: body.model_request_overages_enabled,
 172                        model_request_overages_spend_limit_in_cents,
 173                    },
 174                )
 175                .await?
 176        };
 177
 178    SnowflakeRow::new(
 179        "Billing Preferences Updated",
 180        Some(user.metrics_id),
 181        user.admin,
 182        None,
 183        json!({
 184            "user_id": user.id,
 185            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 186            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 187            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 188        }),
 189    )
 190    .write(&app.kinesis_client, &app.config.kinesis_stream)
 191    .await
 192    .log_err();
 193
 194    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 195
 196    Ok(Json(BillingPreferencesResponse {
 197        trial_started_at: billing_customer
 198            .and_then(|billing_customer| billing_customer.trial_started_at)
 199            .map(|trial_started_at| {
 200                trial_started_at
 201                    .and_utc()
 202                    .to_rfc3339_opts(SecondsFormat::Millis, true)
 203            }),
 204        max_monthly_llm_usage_spending_in_cents: billing_preferences
 205            .max_monthly_llm_usage_spending_in_cents,
 206        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 207        model_request_overages_spend_limit_in_cents: billing_preferences
 208            .model_request_overages_spend_limit_in_cents,
 209    }))
 210}
 211
 212#[derive(Debug, Deserialize)]
 213struct ListBillingSubscriptionsParams {
 214    github_user_id: i32,
 215}
 216
 217#[derive(Debug, Serialize)]
 218struct BillingSubscriptionJson {
 219    id: BillingSubscriptionId,
 220    name: String,
 221    status: StripeSubscriptionStatus,
 222    period: Option<BillingSubscriptionPeriodJson>,
 223    trial_end_at: Option<String>,
 224    cancel_at: Option<String>,
 225    /// Whether this subscription can be canceled.
 226    is_cancelable: bool,
 227}
 228
 229#[derive(Debug, Serialize)]
 230struct BillingSubscriptionPeriodJson {
 231    start_at: String,
 232    end_at: String,
 233}
 234
 235#[derive(Debug, Serialize)]
 236struct ListBillingSubscriptionsResponse {
 237    subscriptions: Vec<BillingSubscriptionJson>,
 238}
 239
 240async fn list_billing_subscriptions(
 241    Extension(app): Extension<Arc<AppState>>,
 242    Query(params): Query<ListBillingSubscriptionsParams>,
 243) -> Result<Json<ListBillingSubscriptionsResponse>> {
 244    let user = app
 245        .db
 246        .get_user_by_github_user_id(params.github_user_id)
 247        .await?
 248        .context("user not found")?;
 249
 250    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 251
 252    Ok(Json(ListBillingSubscriptionsResponse {
 253        subscriptions: subscriptions
 254            .into_iter()
 255            .map(|subscription| BillingSubscriptionJson {
 256                id: subscription.id,
 257                name: match subscription.kind {
 258                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 259                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
 260                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
 261                    None => "Zed LLM Usage".to_string(),
 262                },
 263                status: subscription.stripe_subscription_status,
 264                period: maybe!({
 265                    let start_at = subscription.current_period_start_at()?;
 266                    let end_at = subscription.current_period_end_at()?;
 267
 268                    Some(BillingSubscriptionPeriodJson {
 269                        start_at: start_at.to_rfc3339_opts(SecondsFormat::Millis, true),
 270                        end_at: end_at.to_rfc3339_opts(SecondsFormat::Millis, true),
 271                    })
 272                }),
 273                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
 274                    maybe!({
 275                        let end_at = subscription.stripe_current_period_end?;
 276                        let end_at = DateTime::from_timestamp(end_at, 0)?;
 277
 278                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
 279                    })
 280                } else {
 281                    None
 282                },
 283                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 284                    cancel_at
 285                        .and_utc()
 286                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 287                }),
 288                is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree)
 289                    && subscription.stripe_subscription_status.is_cancelable()
 290                    && subscription.stripe_cancel_at.is_none(),
 291            })
 292            .collect(),
 293    }))
 294}
 295
 296#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
 297#[serde(rename_all = "snake_case")]
 298enum ProductCode {
 299    ZedPro,
 300    ZedProTrial,
 301}
 302
 303#[derive(Debug, Deserialize)]
 304struct CreateBillingSubscriptionBody {
 305    github_user_id: i32,
 306    product: ProductCode,
 307}
 308
 309#[derive(Debug, Serialize)]
 310struct CreateBillingSubscriptionResponse {
 311    checkout_session_url: String,
 312}
 313
 314/// Initiates a Stripe Checkout session for creating a billing subscription.
 315async fn create_billing_subscription(
 316    Extension(app): Extension<Arc<AppState>>,
 317    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 318) -> Result<Json<CreateBillingSubscriptionResponse>> {
 319    let user = app
 320        .db
 321        .get_user_by_github_user_id(body.github_user_id)
 322        .await?
 323        .context("user not found")?;
 324
 325    let Some(stripe_billing) = app.stripe_billing.clone() else {
 326        log::error!("failed to retrieve Stripe billing object");
 327        Err(Error::http(
 328            StatusCode::NOT_IMPLEMENTED,
 329            "not supported".into(),
 330        ))?
 331    };
 332
 333    if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
 334        let is_checkout_allowed = body.product == ProductCode::ZedProTrial
 335            && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
 336
 337        if !is_checkout_allowed {
 338            return Err(Error::http(
 339                StatusCode::CONFLICT,
 340                "user already has an active subscription".into(),
 341            ));
 342        }
 343    }
 344
 345    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 346    if let Some(existing_billing_customer) = &existing_billing_customer {
 347        if existing_billing_customer.has_overdue_invoices {
 348            return Err(Error::http(
 349                StatusCode::PAYMENT_REQUIRED,
 350                "user has overdue invoices".into(),
 351            ));
 352        }
 353    }
 354
 355    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 356        let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
 357        if let Some(email) = user.email_address.as_deref() {
 358            stripe_billing
 359                .client()
 360                .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
 361                .await
 362                // Update of email address is best-effort - continue checkout even if it fails
 363                .context("error updating stripe customer email address")
 364                .log_err();
 365        }
 366        customer_id
 367    } else {
 368        stripe_billing
 369            .find_or_create_customer_by_email(user.email_address.as_deref())
 370            .await?
 371    };
 372
 373    let success_url = format!(
 374        "{}/account?checkout_complete=1",
 375        app.config.zed_dot_dev_url()
 376    );
 377
 378    let checkout_session_url = match body.product {
 379        ProductCode::ZedPro => {
 380            stripe_billing
 381                .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
 382                .await?
 383        }
 384        ProductCode::ZedProTrial => {
 385            if let Some(existing_billing_customer) = &existing_billing_customer {
 386                if existing_billing_customer.trial_started_at.is_some() {
 387                    return Err(Error::http(
 388                        StatusCode::FORBIDDEN,
 389                        "user already used free trial".into(),
 390                    ));
 391                }
 392            }
 393
 394            let feature_flags = app.db.get_user_flags(user.id).await?;
 395
 396            stripe_billing
 397                .checkout_with_zed_pro_trial(
 398                    &customer_id,
 399                    &user.github_login,
 400                    feature_flags,
 401                    &success_url,
 402                )
 403                .await?
 404        }
 405    };
 406
 407    Ok(Json(CreateBillingSubscriptionResponse {
 408        checkout_session_url,
 409    }))
 410}
 411
 412#[derive(Debug, PartialEq, Deserialize)]
 413#[serde(rename_all = "snake_case")]
 414enum ManageSubscriptionIntent {
 415    /// The user intends to manage their subscription.
 416    ///
 417    /// This will open the Stripe billing portal without putting the user in a specific flow.
 418    ManageSubscription,
 419    /// The user intends to update their payment method.
 420    UpdatePaymentMethod,
 421    /// The user intends to upgrade to Zed Pro.
 422    UpgradeToPro,
 423    /// The user intends to cancel their subscription.
 424    Cancel,
 425    /// The user intends to stop the cancellation of their subscription.
 426    StopCancellation,
 427}
 428
 429#[derive(Debug, Deserialize)]
 430struct ManageBillingSubscriptionBody {
 431    github_user_id: i32,
 432    intent: ManageSubscriptionIntent,
 433    /// The ID of the subscription to manage.
 434    subscription_id: BillingSubscriptionId,
 435    redirect_to: Option<String>,
 436}
 437
 438#[derive(Debug, Serialize)]
 439struct ManageBillingSubscriptionResponse {
 440    billing_portal_session_url: Option<String>,
 441}
 442
 443/// Initiates a Stripe customer portal session for managing a billing subscription.
 444async fn manage_billing_subscription(
 445    Extension(app): Extension<Arc<AppState>>,
 446    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 447) -> Result<Json<ManageBillingSubscriptionResponse>> {
 448    let user = app
 449        .db
 450        .get_user_by_github_user_id(body.github_user_id)
 451        .await?
 452        .context("user not found")?;
 453
 454    let Some(stripe_client) = app.real_stripe_client.clone() else {
 455        log::error!("failed to retrieve Stripe client");
 456        Err(Error::http(
 457            StatusCode::NOT_IMPLEMENTED,
 458            "not supported".into(),
 459        ))?
 460    };
 461
 462    let Some(stripe_billing) = app.stripe_billing.clone() else {
 463        log::error!("failed to retrieve Stripe billing object");
 464        Err(Error::http(
 465            StatusCode::NOT_IMPLEMENTED,
 466            "not supported".into(),
 467        ))?
 468    };
 469
 470    let customer = app
 471        .db
 472        .get_billing_customer_by_user_id(user.id)
 473        .await?
 474        .context("billing customer not found")?;
 475    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 476        .context("failed to parse customer ID")?;
 477
 478    let subscription = app
 479        .db
 480        .get_billing_subscription_by_id(body.subscription_id)
 481        .await?
 482        .context("subscription not found")?;
 483    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 484        .context("failed to parse subscription ID")?;
 485
 486    if body.intent == ManageSubscriptionIntent::StopCancellation {
 487        let updated_stripe_subscription = Subscription::update(
 488            &stripe_client,
 489            &subscription_id,
 490            stripe::UpdateSubscription {
 491                cancel_at_period_end: Some(false),
 492                ..Default::default()
 493            },
 494        )
 495        .await?;
 496
 497        app.db
 498            .update_billing_subscription(
 499                subscription.id,
 500                &UpdateBillingSubscriptionParams {
 501                    stripe_cancel_at: ActiveValue::set(
 502                        updated_stripe_subscription
 503                            .cancel_at
 504                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 505                            .map(|time| time.naive_utc()),
 506                    ),
 507                    ..Default::default()
 508                },
 509            )
 510            .await?;
 511
 512        return Ok(Json(ManageBillingSubscriptionResponse {
 513            billing_portal_session_url: None,
 514        }));
 515    }
 516
 517    let flow = match body.intent {
 518        ManageSubscriptionIntent::ManageSubscription => None,
 519        ManageSubscriptionIntent::UpgradeToPro => {
 520            let zed_pro_price_id: stripe::PriceId =
 521                stripe_billing.zed_pro_price_id().await?.try_into()?;
 522            let zed_free_price_id: stripe::PriceId =
 523                stripe_billing.zed_free_price_id().await?.try_into()?;
 524
 525            let stripe_subscription =
 526                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 527
 528            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 529                && stripe_subscription.items.data.iter().any(|item| {
 530                    item.price
 531                        .as_ref()
 532                        .map_or(false, |price| price.id == zed_pro_price_id)
 533                });
 534            if is_on_zed_pro_trial {
 535                let payment_methods = PaymentMethod::list(
 536                    &stripe_client,
 537                    &stripe::ListPaymentMethods {
 538                        customer: Some(stripe_subscription.customer.id()),
 539                        ..Default::default()
 540                    },
 541                )
 542                .await?;
 543
 544                let has_payment_method = !payment_methods.data.is_empty();
 545                if !has_payment_method {
 546                    return Err(Error::http(
 547                        StatusCode::BAD_REQUEST,
 548                        "missing payment method".into(),
 549                    ));
 550                }
 551
 552                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 553                Subscription::update(
 554                    &stripe_client,
 555                    &stripe_subscription.id,
 556                    stripe::UpdateSubscription {
 557                        trial_end: Some(stripe::Scheduled::now()),
 558                        ..Default::default()
 559                    },
 560                )
 561                .await?;
 562
 563                return Ok(Json(ManageBillingSubscriptionResponse {
 564                    billing_portal_session_url: None,
 565                }));
 566            }
 567
 568            let subscription_item_to_update = stripe_subscription
 569                .items
 570                .data
 571                .iter()
 572                .find_map(|item| {
 573                    let price = item.price.as_ref()?;
 574
 575                    if price.id == zed_free_price_id {
 576                        Some(item.id.clone())
 577                    } else {
 578                        None
 579                    }
 580                })
 581                .context("No subscription item to update")?;
 582
 583            Some(CreateBillingPortalSessionFlowData {
 584                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 585                subscription_update_confirm: Some(
 586                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 587                        subscription: subscription.stripe_subscription_id,
 588                        items: vec![
 589                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 590                                id: subscription_item_to_update.to_string(),
 591                                price: Some(zed_pro_price_id.to_string()),
 592                                quantity: Some(1),
 593                            },
 594                        ],
 595                        discounts: None,
 596                    },
 597                ),
 598                ..Default::default()
 599            })
 600        }
 601        ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
 602            type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
 603            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 604                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 605                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 606                    return_url: format!(
 607                        "{}{path}",
 608                        app.config.zed_dot_dev_url(),
 609                        path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
 610                    ),
 611                }),
 612                ..Default::default()
 613            }),
 614            ..Default::default()
 615        }),
 616        ManageSubscriptionIntent::Cancel => {
 617            if subscription.kind == Some(SubscriptionKind::ZedFree) {
 618                return Err(Error::http(
 619                    StatusCode::BAD_REQUEST,
 620                    "free subscription cannot be canceled".into(),
 621                ));
 622            }
 623
 624            Some(CreateBillingPortalSessionFlowData {
 625                type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 626                after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 627                    type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 628                    redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 629                        return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 630                    }),
 631                    ..Default::default()
 632                }),
 633                subscription_cancel: Some(
 634                    stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 635                        subscription: subscription.stripe_subscription_id,
 636                        retention: None,
 637                    },
 638                ),
 639                ..Default::default()
 640            })
 641        }
 642        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 643    };
 644
 645    let mut params = CreateBillingPortalSession::new(customer_id);
 646    params.flow_data = flow;
 647    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 648    params.return_url = Some(&return_url);
 649
 650    let session = BillingPortalSession::create(&stripe_client, params).await?;
 651
 652    Ok(Json(ManageBillingSubscriptionResponse {
 653        billing_portal_session_url: Some(session.url),
 654    }))
 655}
 656
 657#[derive(Debug, Deserialize)]
 658struct SyncBillingSubscriptionBody {
 659    github_user_id: i32,
 660}
 661
 662#[derive(Debug, Serialize)]
 663struct SyncBillingSubscriptionResponse {
 664    stripe_customer_id: String,
 665}
 666
 667async fn sync_billing_subscription(
 668    Extension(app): Extension<Arc<AppState>>,
 669    extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
 670) -> Result<Json<SyncBillingSubscriptionResponse>> {
 671    let Some(stripe_client) = app.stripe_client.clone() else {
 672        log::error!("failed to retrieve Stripe client");
 673        Err(Error::http(
 674            StatusCode::NOT_IMPLEMENTED,
 675            "not supported".into(),
 676        ))?
 677    };
 678
 679    let user = app
 680        .db
 681        .get_user_by_github_user_id(body.github_user_id)
 682        .await?
 683        .context("user not found")?;
 684
 685    let billing_customer = app
 686        .db
 687        .get_billing_customer_by_user_id(user.id)
 688        .await?
 689        .context("billing customer not found")?;
 690    let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
 691
 692    let subscriptions = stripe_client
 693        .list_subscriptions_for_customer(&stripe_customer_id)
 694        .await?;
 695
 696    for subscription in subscriptions {
 697        let subscription_id = subscription.id.clone();
 698
 699        sync_subscription(&app, &stripe_client, subscription)
 700            .await
 701            .with_context(|| {
 702                format!(
 703                    "failed to sync subscription {subscription_id} for user {}",
 704                    user.id,
 705                )
 706            })?;
 707    }
 708
 709    Ok(Json(SyncBillingSubscriptionResponse {
 710        stripe_customer_id: billing_customer.stripe_customer_id.clone(),
 711    }))
 712}
 713
 714/// The amount of time we wait in between each poll of Stripe events.
 715///
 716/// This value should strike a balance between:
 717///   1. Being short enough that we update quickly when something in Stripe changes
 718///   2. Being long enough that we don't eat into our rate limits.
 719///
 720/// As a point of reference, the Sequin folks say they have this at **500ms**:
 721///
 722/// > We poll the Stripe /events endpoint every 500ms per account
 723/// >
 724/// > — https://blog.sequinstream.com/events-not-webhooks/
 725const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 726
 727/// The maximum number of events to return per page.
 728///
 729/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
 730///
 731/// > Limit can range between 1 and 100, and the default is 10.
 732const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 733
 734/// The number of pages consisting entirely of already-processed events that we
 735/// will see before we stop retrieving events.
 736///
 737/// This is used to prevent over-fetching the Stripe events API for events we've
 738/// already seen and processed.
 739const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 740
 741/// Polls the Stripe events API periodically to reconcile the records in our
 742/// database with the data in Stripe.
 743pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 744    let Some(real_stripe_client) = app.real_stripe_client.clone() else {
 745        log::warn!("failed to retrieve Stripe client");
 746        return;
 747    };
 748    let Some(stripe_client) = app.stripe_client.clone() else {
 749        log::warn!("failed to retrieve Stripe client");
 750        return;
 751    };
 752
 753    let executor = app.executor.clone();
 754    executor.spawn_detached({
 755        let executor = executor.clone();
 756        async move {
 757            loop {
 758                poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
 759                    .await
 760                    .log_err();
 761
 762                executor.sleep(POLL_EVENTS_INTERVAL).await;
 763            }
 764        }
 765    });
 766}
 767
 768async fn poll_stripe_events(
 769    app: &Arc<AppState>,
 770    rpc_server: &Arc<Server>,
 771    stripe_client: &Arc<dyn StripeClient>,
 772    real_stripe_client: &stripe::Client,
 773) -> anyhow::Result<()> {
 774    fn event_type_to_string(event_type: EventType) -> String {
 775        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 776        // so we need to unquote it.
 777        event_type.to_string().trim_matches('"').to_string()
 778    }
 779
 780    let event_types = [
 781        EventType::CustomerCreated,
 782        EventType::CustomerUpdated,
 783        EventType::CustomerSubscriptionCreated,
 784        EventType::CustomerSubscriptionUpdated,
 785        EventType::CustomerSubscriptionPaused,
 786        EventType::CustomerSubscriptionResumed,
 787        EventType::CustomerSubscriptionDeleted,
 788    ]
 789    .into_iter()
 790    .map(event_type_to_string)
 791    .collect::<Vec<_>>();
 792
 793    let mut pages_of_already_processed_events = 0;
 794    let mut unprocessed_events = Vec::new();
 795
 796    log::info!(
 797        "Stripe events: starting retrieval for {}",
 798        event_types.join(", ")
 799    );
 800    let mut params = ListEvents::new();
 801    params.types = Some(event_types.clone());
 802    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 803
 804    let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
 805        .await?
 806        .paginate(params);
 807
 808    loop {
 809        let processed_event_ids = {
 810            let event_ids = event_pages
 811                .page
 812                .data
 813                .iter()
 814                .map(|event| event.id.as_str())
 815                .collect::<Vec<_>>();
 816            app.db
 817                .get_processed_stripe_events_by_event_ids(&event_ids)
 818                .await?
 819                .into_iter()
 820                .map(|event| event.stripe_event_id)
 821                .collect::<Vec<_>>()
 822        };
 823
 824        let mut processed_events_in_page = 0;
 825        let events_in_page = event_pages.page.data.len();
 826        for event in &event_pages.page.data {
 827            if processed_event_ids.contains(&event.id.to_string()) {
 828                processed_events_in_page += 1;
 829                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 830            } else {
 831                unprocessed_events.push(event.clone());
 832            }
 833        }
 834
 835        if processed_events_in_page == events_in_page {
 836            pages_of_already_processed_events += 1;
 837        }
 838
 839        if event_pages.page.has_more {
 840            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 841            {
 842                log::info!(
 843                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 844                );
 845                break;
 846            } else {
 847                log::info!("Stripe events: retrieving next page");
 848                event_pages = event_pages.next(&real_stripe_client).await?;
 849            }
 850        } else {
 851            break;
 852        }
 853    }
 854
 855    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 856
 857    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 858    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 859
 860    for event in unprocessed_events {
 861        let event_id = event.id.clone();
 862        let processed_event_params = CreateProcessedStripeEventParams {
 863            stripe_event_id: event.id.to_string(),
 864            stripe_event_type: event_type_to_string(event.type_),
 865            stripe_event_created_timestamp: event.created,
 866        };
 867
 868        // If the event has happened too far in the past, we don't want to
 869        // process it and risk overwriting other more-recent updates.
 870        //
 871        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 872        let one_day = Duration::from_secs(24 * 60 * 60);
 873        let a_day_ago = Utc::now() - one_day;
 874        if a_day_ago.timestamp() > event.created {
 875            log::info!(
 876                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 877                event_id
 878            );
 879            app.db
 880                .create_processed_stripe_event(&processed_event_params)
 881                .await?;
 882
 883            continue;
 884        }
 885
 886        let process_result = match event.type_ {
 887            EventType::CustomerCreated | EventType::CustomerUpdated => {
 888                handle_customer_event(app, real_stripe_client, event).await
 889            }
 890            EventType::CustomerSubscriptionCreated
 891            | EventType::CustomerSubscriptionUpdated
 892            | EventType::CustomerSubscriptionPaused
 893            | EventType::CustomerSubscriptionResumed
 894            | EventType::CustomerSubscriptionDeleted => {
 895                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 896            }
 897            _ => Ok(()),
 898        };
 899
 900        if let Some(()) = process_result
 901            .with_context(|| format!("failed to process event {event_id} successfully"))
 902            .log_err()
 903        {
 904            app.db
 905                .create_processed_stripe_event(&processed_event_params)
 906                .await?;
 907        }
 908    }
 909
 910    Ok(())
 911}
 912
 913async fn handle_customer_event(
 914    app: &Arc<AppState>,
 915    _stripe_client: &stripe::Client,
 916    event: stripe::Event,
 917) -> anyhow::Result<()> {
 918    let EventObject::Customer(customer) = event.data.object else {
 919        bail!("unexpected event payload for {}", event.id);
 920    };
 921
 922    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 923
 924    let Some(email) = customer.email else {
 925        log::info!("Stripe customer has no email: skipping");
 926        return Ok(());
 927    };
 928
 929    let Some(user) = app.db.get_user_by_email(&email).await? else {
 930        log::info!("no user found for email: skipping");
 931        return Ok(());
 932    };
 933
 934    if let Some(existing_customer) = app
 935        .db
 936        .get_billing_customer_by_stripe_customer_id(&customer.id)
 937        .await?
 938    {
 939        app.db
 940            .update_billing_customer(
 941                existing_customer.id,
 942                &UpdateBillingCustomerParams {
 943                    // For now we just leave the information as-is, as it is not
 944                    // likely to change.
 945                    ..Default::default()
 946                },
 947            )
 948            .await?;
 949    } else {
 950        app.db
 951            .create_billing_customer(&CreateBillingCustomerParams {
 952                user_id: user.id,
 953                stripe_customer_id: customer.id.to_string(),
 954            })
 955            .await?;
 956    }
 957
 958    Ok(())
 959}
 960
 961async fn sync_subscription(
 962    app: &Arc<AppState>,
 963    stripe_client: &Arc<dyn StripeClient>,
 964    subscription: StripeSubscription,
 965) -> anyhow::Result<billing_customer::Model> {
 966    let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
 967        stripe_billing
 968            .determine_subscription_kind(&subscription)
 969            .await
 970    } else {
 971        None
 972    };
 973
 974    let billing_customer =
 975        find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
 976            .await?
 977            .context("billing customer not found")?;
 978
 979    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
 980        if subscription.status == SubscriptionStatus::Trialing {
 981            let current_period_start =
 982                DateTime::from_timestamp(subscription.current_period_start, 0)
 983                    .context("No trial subscription period start")?;
 984
 985            app.db
 986                .update_billing_customer(
 987                    billing_customer.id,
 988                    &UpdateBillingCustomerParams {
 989                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
 990                        ..Default::default()
 991                    },
 992                )
 993                .await?;
 994        }
 995    }
 996
 997    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
 998        && subscription
 999            .cancellation_details
1000            .as_ref()
1001            .and_then(|details| details.reason)
1002            .map_or(false, |reason| {
1003                reason == StripeCancellationDetailsReason::PaymentFailed
1004            });
1005
1006    if was_canceled_due_to_payment_failure {
1007        app.db
1008            .update_billing_customer(
1009                billing_customer.id,
1010                &UpdateBillingCustomerParams {
1011                    has_overdue_invoices: ActiveValue::set(true),
1012                    ..Default::default()
1013                },
1014            )
1015            .await?;
1016    }
1017
1018    if let Some(existing_subscription) = app
1019        .db
1020        .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
1021        .await?
1022    {
1023        app.db
1024            .update_billing_subscription(
1025                existing_subscription.id,
1026                &UpdateBillingSubscriptionParams {
1027                    billing_customer_id: ActiveValue::set(billing_customer.id),
1028                    kind: ActiveValue::set(subscription_kind),
1029                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1030                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1031                    stripe_cancel_at: ActiveValue::set(
1032                        subscription
1033                            .cancel_at
1034                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1035                            .map(|time| time.naive_utc()),
1036                    ),
1037                    stripe_cancellation_reason: ActiveValue::set(
1038                        subscription
1039                            .cancellation_details
1040                            .and_then(|details| details.reason)
1041                            .map(|reason| reason.into()),
1042                    ),
1043                    stripe_current_period_start: ActiveValue::set(Some(
1044                        subscription.current_period_start,
1045                    )),
1046                    stripe_current_period_end: ActiveValue::set(Some(
1047                        subscription.current_period_end,
1048                    )),
1049                },
1050            )
1051            .await?;
1052    } else {
1053        if let Some(existing_subscription) = app
1054            .db
1055            .get_active_billing_subscription(billing_customer.user_id)
1056            .await?
1057        {
1058            if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
1059                && subscription_kind == Some(SubscriptionKind::ZedProTrial)
1060            {
1061                let stripe_subscription_id = StripeSubscriptionId(
1062                    existing_subscription.stripe_subscription_id.clone().into(),
1063                );
1064
1065                stripe_client
1066                    .cancel_subscription(&stripe_subscription_id)
1067                    .await?;
1068            } else {
1069                // If the user already has an active billing subscription, ignore the
1070                // event and return an `Ok` to signal that it was processed
1071                // successfully.
1072                //
1073                // There is the possibility that this could cause us to not create a
1074                // subscription in the following scenario:
1075                //
1076                //   1. User has an active subscription A
1077                //   2. User cancels subscription A
1078                //   3. User creates a new subscription B
1079                //   4. We process the new subscription B before the cancellation of subscription A
1080                //   5. User ends up with no subscriptions
1081                //
1082                // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1083
1084                log::info!(
1085                    "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1086                    user_id = billing_customer.user_id,
1087                    subscription_id = subscription.id
1088                );
1089                return Ok(billing_customer);
1090            }
1091        }
1092
1093        app.db
1094            .create_billing_subscription(&CreateBillingSubscriptionParams {
1095                billing_customer_id: billing_customer.id,
1096                kind: subscription_kind,
1097                stripe_subscription_id: subscription.id.to_string(),
1098                stripe_subscription_status: subscription.status.into(),
1099                stripe_cancellation_reason: subscription
1100                    .cancellation_details
1101                    .and_then(|details| details.reason)
1102                    .map(|reason| reason.into()),
1103                stripe_current_period_start: Some(subscription.current_period_start),
1104                stripe_current_period_end: Some(subscription.current_period_end),
1105            })
1106            .await?;
1107    }
1108
1109    if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1110        if subscription.status == SubscriptionStatus::Canceled
1111            || subscription.status == SubscriptionStatus::Paused
1112        {
1113            let already_has_active_billing_subscription = app
1114                .db
1115                .has_active_billing_subscription(billing_customer.user_id)
1116                .await?;
1117            if !already_has_active_billing_subscription {
1118                let stripe_customer_id =
1119                    StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1120
1121                stripe_billing
1122                    .subscribe_to_zed_free(stripe_customer_id)
1123                    .await?;
1124            }
1125        }
1126    }
1127
1128    Ok(billing_customer)
1129}
1130
1131async fn handle_customer_subscription_event(
1132    app: &Arc<AppState>,
1133    rpc_server: &Arc<Server>,
1134    stripe_client: &Arc<dyn StripeClient>,
1135    event: stripe::Event,
1136) -> anyhow::Result<()> {
1137    let EventObject::Subscription(subscription) = event.data.object else {
1138        bail!("unexpected event payload for {}", event.id);
1139    };
1140
1141    log::info!("handling Stripe {} event: {}", event.type_, event.id);
1142
1143    let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
1144
1145    // When the user's subscription changes, push down any changes to their plan.
1146    rpc_server
1147        .update_plan_for_user(billing_customer.user_id)
1148        .await
1149        .trace_err();
1150
1151    // When the user's subscription changes, we want to refresh their LLM tokens
1152    // to either grant/revoke access.
1153    rpc_server
1154        .refresh_llm_tokens_for_user(billing_customer.user_id)
1155        .await;
1156
1157    Ok(())
1158}
1159
/// Query parameters accepted by [`get_current_usage`].
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID used to look up the user whose usage is requested.
    github_user_id: i32,
}
1164
/// Usage figures for one kind of metered resource, as reported by
/// [`get_current_usage`].
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Amount used so far in the current subscription period.
    pub used: i32,
    /// The plan's limit, or `None` when the plan's usage limit is unlimited.
    pub limit: Option<i32>,
    /// Amount left before `limit` is reached; `None` when there is no limit.
    pub remaining: Option<i32>,
}
1171
/// Request count taken from a single subscription usage meter, identified by
/// model and completion mode.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the usage meter refers to.
    pub model: String,
    /// Completion mode recorded on the usage meter.
    pub mode: CompletionMode,
    /// Number of requests recorded on the usage meter.
    pub requests: i32,
}
1178
/// The usage portion of a [`GetCurrentUsageResponse`].
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Overall model request usage measured against the plan's limit.
    pub model_requests: UsageCounts,
    /// Per usage-meter breakdown of model requests.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit prediction usage measured against the plan's limit.
    pub edit_predictions: UsageCounts,
}
1185
/// Response body returned by [`get_current_usage`].
///
/// The `Default` value (empty `plan`, no `current_usage`) is what the handler
/// returns when the user has no active subscription or the subscription has
/// no current period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// String form of the user's plan.
    pub plan: String,
    /// Usage for the current subscription period, when one is available.
    pub current_usage: Option<CurrentUsage>,
}
1191
1192async fn get_current_usage(
1193    Extension(app): Extension<Arc<AppState>>,
1194    Query(params): Query<GetCurrentUsageParams>,
1195) -> Result<Json<GetCurrentUsageResponse>> {
1196    let user = app
1197        .db
1198        .get_user_by_github_user_id(params.github_user_id)
1199        .await?
1200        .context("user not found")?;
1201
1202    let feature_flags = app.db.get_user_flags(user.id).await?;
1203    let has_extended_trial = feature_flags
1204        .iter()
1205        .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1206
1207    let Some(llm_db) = app.llm_db.clone() else {
1208        return Err(Error::http(
1209            StatusCode::NOT_IMPLEMENTED,
1210            "LLM database not available".into(),
1211        ));
1212    };
1213
1214    let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1215        return Ok(Json(GetCurrentUsageResponse::default()));
1216    };
1217
1218    let subscription_period = maybe!({
1219        let period_start_at = subscription.current_period_start_at()?;
1220        let period_end_at = subscription.current_period_end_at()?;
1221
1222        Some((period_start_at, period_end_at))
1223    });
1224
1225    let Some((period_start_at, period_end_at)) = subscription_period else {
1226        return Ok(Json(GetCurrentUsageResponse::default()));
1227    };
1228
1229    let usage = llm_db
1230        .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1231        .await?;
1232
1233    let plan = subscription
1234        .kind
1235        .map(Into::into)
1236        .unwrap_or(zed_llm_client::Plan::ZedFree);
1237
1238    let model_requests_limit = match plan.model_requests_limit() {
1239        zed_llm_client::UsageLimit::Limited(limit) => {
1240            let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1241                1_000
1242            } else {
1243                limit
1244            };
1245
1246            Some(limit)
1247        }
1248        zed_llm_client::UsageLimit::Unlimited => None,
1249    };
1250
1251    let edit_predictions_limit = match plan.edit_predictions_limit() {
1252        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1253        zed_llm_client::UsageLimit::Unlimited => None,
1254    };
1255
1256    let Some(usage) = usage else {
1257        return Ok(Json(GetCurrentUsageResponse {
1258            plan: plan.as_str().to_string(),
1259            current_usage: Some(CurrentUsage {
1260                model_requests: UsageCounts {
1261                    used: 0,
1262                    limit: model_requests_limit,
1263                    remaining: model_requests_limit,
1264                },
1265                model_request_usage: Vec::new(),
1266                edit_predictions: UsageCounts {
1267                    used: 0,
1268                    limit: edit_predictions_limit,
1269                    remaining: edit_predictions_limit,
1270                },
1271            }),
1272        }));
1273    };
1274
1275    let subscription_usage_meters = llm_db
1276        .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1277        .await?;
1278
1279    let model_request_usage = subscription_usage_meters
1280        .into_iter()
1281        .filter_map(|(usage_meter, _usage)| {
1282            let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1283
1284            Some(ModelRequestUsage {
1285                model: model.name.clone(),
1286                mode: usage_meter.mode,
1287                requests: usage_meter.requests,
1288            })
1289        })
1290        .collect::<Vec<_>>();
1291
1292    Ok(Json(GetCurrentUsageResponse {
1293        plan: plan.as_str().to_string(),
1294        current_usage: Some(CurrentUsage {
1295            model_requests: UsageCounts {
1296                used: usage.model_requests,
1297                limit: model_requests_limit,
1298                remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1299            },
1300            model_request_usage,
1301            edit_predictions: UsageCounts {
1302                used: usage.edit_predictions,
1303                limit: edit_predictions_limit,
1304                remaining: edit_predictions_limit
1305                    .map(|limit| (limit - usage.edit_predictions).max(0)),
1306            },
1307        }),
1308    }))
1309}
1310
1311impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1312    fn from(value: SubscriptionStatus) -> Self {
1313        match value {
1314            SubscriptionStatus::Incomplete => Self::Incomplete,
1315            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1316            SubscriptionStatus::Trialing => Self::Trialing,
1317            SubscriptionStatus::Active => Self::Active,
1318            SubscriptionStatus::PastDue => Self::PastDue,
1319            SubscriptionStatus::Canceled => Self::Canceled,
1320            SubscriptionStatus::Unpaid => Self::Unpaid,
1321            SubscriptionStatus::Paused => Self::Paused,
1322        }
1323    }
1324}
1325
1326impl From<CancellationDetailsReason> for StripeCancellationReason {
1327    fn from(value: CancellationDetailsReason) -> Self {
1328        match value {
1329            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1330            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1331            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1332        }
1333    }
1334}
1335
1336/// Finds or creates a billing customer using the provided customer.
1337pub async fn find_or_create_billing_customer(
1338    app: &Arc<AppState>,
1339    stripe_client: &dyn StripeClient,
1340    customer_id: &StripeCustomerId,
1341) -> anyhow::Result<Option<billing_customer::Model>> {
1342    // If we already have a billing customer record associated with the Stripe customer,
1343    // there's nothing more we need to do.
1344    if let Some(billing_customer) = app
1345        .db
1346        .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
1347        .await?
1348    {
1349        return Ok(Some(billing_customer));
1350    }
1351
1352    let customer = stripe_client.get_customer(customer_id).await?;
1353
1354    let Some(email) = customer.email else {
1355        return Ok(None);
1356    };
1357
1358    let Some(user) = app.db.get_user_by_email(&email).await? else {
1359        return Ok(None);
1360    };
1361
1362    let billing_customer = app
1363        .db
1364        .create_billing_customer(&CreateBillingCustomerParams {
1365            user_id: user.id,
1366            stripe_customer_id: customer.id.to_string(),
1367        })
1368        .await?;
1369
1370    Ok(Some(billing_customer))
1371}
1372
/// How long the background task sleeps between runs of the LLM request usage
/// sync with Stripe.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1374
1375pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1376    let Some(stripe_billing) = app.stripe_billing.clone() else {
1377        log::warn!("failed to retrieve Stripe billing object");
1378        return;
1379    };
1380    let Some(llm_db) = app.llm_db.clone() else {
1381        log::warn!("failed to retrieve LLM database");
1382        return;
1383    };
1384
1385    let executor = app.executor.clone();
1386    executor.spawn_detached({
1387        let executor = executor.clone();
1388        async move {
1389            loop {
1390                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1391                    .await
1392                    .context("failed to sync LLM request usage to Stripe")
1393                    .trace_err();
1394                executor
1395                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1396                    .await;
1397            }
1398        }
1399    });
1400}
1401
1402async fn sync_model_request_usage_with_stripe(
1403    app: &Arc<AppState>,
1404    llm_db: &Arc<LlmDatabase>,
1405    stripe_billing: &Arc<StripeBilling>,
1406) -> anyhow::Result<()> {
1407    log::info!("Stripe usage sync: Starting");
1408    let started_at = Utc::now();
1409
1410    let staff_users = app.db.get_staff_users().await?;
1411    let staff_user_ids = staff_users
1412        .iter()
1413        .map(|user| user.id)
1414        .collect::<HashSet<UserId>>();
1415
1416    let usage_meters = llm_db
1417        .get_current_subscription_usage_meters(Utc::now())
1418        .await?;
1419    let usage_meters = usage_meters
1420        .into_iter()
1421        .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1422        .collect::<Vec<_>>();
1423    let user_ids = usage_meters
1424        .iter()
1425        .map(|(_, usage)| usage.user_id)
1426        .collect::<HashSet<UserId>>();
1427    let billing_subscriptions = app
1428        .db
1429        .get_active_zed_pro_billing_subscriptions(user_ids)
1430        .await?;
1431
1432    let claude_sonnet_4 = stripe_billing
1433        .find_price_by_lookup_key("claude-sonnet-4-requests")
1434        .await?;
1435    let claude_sonnet_4_max = stripe_billing
1436        .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1437        .await?;
1438    let claude_opus_4 = stripe_billing
1439        .find_price_by_lookup_key("claude-opus-4-requests")
1440        .await?;
1441    let claude_opus_4_max = stripe_billing
1442        .find_price_by_lookup_key("claude-opus-4-requests-max")
1443        .await?;
1444    let claude_3_5_sonnet = stripe_billing
1445        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1446        .await?;
1447    let claude_3_7_sonnet = stripe_billing
1448        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1449        .await?;
1450    let claude_3_7_sonnet_max = stripe_billing
1451        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1452        .await?;
1453
1454    let usage_meter_count = usage_meters.len();
1455
1456    log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
1457
1458    for (usage_meter, usage) in usage_meters {
1459        maybe!(async {
1460            let Some((billing_customer, billing_subscription)) =
1461                billing_subscriptions.get(&usage.user_id)
1462            else {
1463                bail!(
1464                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1465                    usage.user_id
1466                );
1467            };
1468
1469            let stripe_customer_id =
1470                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1471            let stripe_subscription_id =
1472                StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1473
1474            let model = llm_db.model_by_id(usage_meter.model_id)?;
1475
1476            let (price, meter_event_name) = match model.name.as_str() {
1477                "claude-opus-4" => match usage_meter.mode {
1478                    CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1479                    CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1480                },
1481                "claude-sonnet-4" => match usage_meter.mode {
1482                    CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1483                    CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
1484                },
1485                "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1486                "claude-3-7-sonnet" => match usage_meter.mode {
1487                    CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1488                    CompletionMode::Max => {
1489                        (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1490                    }
1491                },
1492                model_name => {
1493                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1494                }
1495            };
1496
1497            stripe_billing
1498                .subscribe_to_price(&stripe_subscription_id, price)
1499                .await?;
1500            stripe_billing
1501                .bill_model_request_usage(
1502                    &stripe_customer_id,
1503                    meter_event_name,
1504                    usage_meter.requests,
1505                )
1506                .await?;
1507
1508            Ok(())
1509        })
1510        .await
1511        .log_err();
1512    }
1513
1514    log::info!(
1515        "Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
1516        Utc::now() - started_at
1517    );
1518
1519    Ok(())
1520}