billing.rs

   1use anyhow::{Context, anyhow, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  19    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  20    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
  21    EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
  22    SubscriptionStatus,
  23};
  24use util::{ResultExt, maybe};
  25
  26use crate::api::events::SnowflakeRow;
  27use crate::db::billing_subscription::{
  28    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  29};
  30use crate::llm::db::subscription_usage_meter::CompletionMode;
  31use crate::llm::{
  32    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
  33};
  34use crate::rpc::{ResultExt as _, Server};
  35use crate::{AppState, Cents, Error, Result};
  36use crate::{db::UserId, llm::db::LlmDatabase};
  37use crate::{
  38    db::{
  39        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  40        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  41        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  42    },
  43    stripe_billing::StripeBilling,
  44};
  45
  46pub fn router() -> Router {
  47    Router::new()
  48        .route(
  49            "/billing/preferences",
  50            get(get_billing_preferences).put(update_billing_preferences),
  51        )
  52        .route(
  53            "/billing/subscriptions",
  54            get(list_billing_subscriptions).post(create_billing_subscription),
  55        )
  56        .route(
  57            "/billing/subscriptions/manage",
  58            post(manage_billing_subscription),
  59        )
  60        .route(
  61            "/billing/subscriptions/migrate",
  62            post(migrate_to_new_billing),
  63        )
  64        .route(
  65            "/billing/subscriptions/sync",
  66            post(sync_billing_subscription),
  67        )
  68        .route("/billing/monthly_spend", get(get_monthly_spend))
  69        .route("/billing/usage", get(get_current_usage))
  70}
  71
/// Query parameters for `GET /billing/preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID used to look up the user whose preferences are returned.
    github_user_id: i32,
}
  76
/// Response body returned by both the "get" and "update" billing preferences endpoints.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, as an RFC 3339 UTC timestamp with millisecond
    /// precision. `None` when there is no billing customer or no recorded trial start.
    trial_started_at: Option<String>,
    /// Maximum monthly LLM usage spend, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled.
    model_request_overages_enabled: bool,
    /// Spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
  84
  85async fn get_billing_preferences(
  86    Extension(app): Extension<Arc<AppState>>,
  87    Query(params): Query<GetBillingPreferencesParams>,
  88) -> Result<Json<BillingPreferencesResponse>> {
  89    let user = app
  90        .db
  91        .get_user_by_github_user_id(params.github_user_id)
  92        .await?
  93        .ok_or_else(|| anyhow!("user not found"))?;
  94
  95    let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
  96    let preferences = app.db.get_billing_preferences(user.id).await?;
  97
  98    Ok(Json(BillingPreferencesResponse {
  99        trial_started_at: billing_customer
 100            .and_then(|billing_customer| billing_customer.trial_started_at)
 101            .map(|trial_started_at| {
 102                trial_started_at
 103                    .and_utc()
 104                    .to_rfc3339_opts(SecondsFormat::Millis, true)
 105            }),
 106        max_monthly_llm_usage_spending_in_cents: preferences
 107            .as_ref()
 108            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
 109                preferences.max_monthly_llm_usage_spending_in_cents
 110            }),
 111        model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
 112            preferences.model_request_overages_enabled
 113        }),
 114        model_request_overages_spend_limit_in_cents: preferences
 115            .as_ref()
 116            .map_or(0, |preferences| {
 117                preferences.model_request_overages_spend_limit_in_cents
 118            }),
 119    }))
 120}
 121
/// JSON request body for `PUT /billing/preferences`.
///
/// Every field other than `github_user_id` falls back to its type's default
/// (`0` / `false`) when omitted from the request.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID used to look up the user whose preferences are updated.
    github_user_id: i32,
    /// Requested maximum monthly LLM usage spend, in cents.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Requested spend limit for model request overages, in cents.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
 132
 133async fn update_billing_preferences(
 134    Extension(app): Extension<Arc<AppState>>,
 135    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
 136    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 137) -> Result<Json<BillingPreferencesResponse>> {
 138    let user = app
 139        .db
 140        .get_user_by_github_user_id(body.github_user_id)
 141        .await?
 142        .ok_or_else(|| anyhow!("user not found"))?;
 143
 144    let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 145
 146    let max_monthly_llm_usage_spending_in_cents =
 147        body.max_monthly_llm_usage_spending_in_cents.max(0);
 148    let model_request_overages_spend_limit_in_cents =
 149        body.model_request_overages_spend_limit_in_cents.max(0);
 150
 151    let billing_preferences =
 152        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 153            app.db
 154                .update_billing_preferences(
 155                    user.id,
 156                    &UpdateBillingPreferencesParams {
 157                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 158                            max_monthly_llm_usage_spending_in_cents,
 159                        ),
 160                        model_request_overages_enabled: ActiveValue::set(
 161                            body.model_request_overages_enabled,
 162                        ),
 163                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 164                            model_request_overages_spend_limit_in_cents,
 165                        ),
 166                    },
 167                )
 168                .await?
 169        } else {
 170            app.db
 171                .create_billing_preferences(
 172                    user.id,
 173                    &crate::db::CreateBillingPreferencesParams {
 174                        max_monthly_llm_usage_spending_in_cents,
 175                        model_request_overages_enabled: body.model_request_overages_enabled,
 176                        model_request_overages_spend_limit_in_cents,
 177                    },
 178                )
 179                .await?
 180        };
 181
 182    SnowflakeRow::new(
 183        "Billing Preferences Updated",
 184        Some(user.metrics_id),
 185        user.admin,
 186        None,
 187        json!({
 188            "user_id": user.id,
 189            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 190            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 191            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 192        }),
 193    )
 194    .write(&app.kinesis_client, &app.config.kinesis_stream)
 195    .await
 196    .log_err();
 197
 198    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 199
 200    Ok(Json(BillingPreferencesResponse {
 201        trial_started_at: billing_customer
 202            .and_then(|billing_customer| billing_customer.trial_started_at)
 203            .map(|trial_started_at| {
 204                trial_started_at
 205                    .and_utc()
 206                    .to_rfc3339_opts(SecondsFormat::Millis, true)
 207            }),
 208        max_monthly_llm_usage_spending_in_cents: billing_preferences
 209            .max_monthly_llm_usage_spending_in_cents,
 210        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 211        model_request_overages_spend_limit_in_cents: billing_preferences
 212            .model_request_overages_spend_limit_in_cents,
 213    }))
 214}
 215
/// Query parameters for `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}
 220
/// JSON representation of a single billing subscription in the list response.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Database ID of the billing subscription.
    id: BillingSubscriptionId,
    /// Human-readable plan name derived from the subscription kind.
    name: String,
    /// The Stripe subscription status stored for this subscription.
    status: StripeSubscriptionStatus,
    /// When the trial ends, as an RFC 3339 timestamp with millisecond precision.
    /// Only populated for Zed Pro trial subscriptions.
    trial_end_at: Option<String>,
    /// When the subscription is scheduled to cancel, as an RFC 3339 timestamp with
    /// millisecond precision, if a cancellation time is stored.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}
 231
/// Response body for `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The user's billing subscriptions, in the order returned by the database.
    subscriptions: Vec<BillingSubscriptionJson>,
}
 236
 237async fn list_billing_subscriptions(
 238    Extension(app): Extension<Arc<AppState>>,
 239    Query(params): Query<ListBillingSubscriptionsParams>,
 240) -> Result<Json<ListBillingSubscriptionsResponse>> {
 241    let user = app
 242        .db
 243        .get_user_by_github_user_id(params.github_user_id)
 244        .await?
 245        .ok_or_else(|| anyhow!("user not found"))?;
 246
 247    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 248
 249    Ok(Json(ListBillingSubscriptionsResponse {
 250        subscriptions: subscriptions
 251            .into_iter()
 252            .map(|subscription| BillingSubscriptionJson {
 253                id: subscription.id,
 254                name: match subscription.kind {
 255                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 256                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
 257                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
 258                    None => "Zed LLM Usage".to_string(),
 259                },
 260                status: subscription.stripe_subscription_status,
 261                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
 262                    maybe!({
 263                        let end_at = subscription.stripe_current_period_end?;
 264                        let end_at = DateTime::from_timestamp(end_at, 0)?;
 265
 266                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
 267                    })
 268                } else {
 269                    None
 270                },
 271                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 272                    cancel_at
 273                        .and_utc()
 274                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 275                }),
 276                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
 277                    && subscription.stripe_cancel_at.is_none(),
 278            })
 279            .collect(),
 280    }))
 281}
 282
/// The product a checkout session is being created for.
///
/// Deserialized from snake_case strings: `zed_pro`, `zed_pro_trial`, `zed_free`.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// The Zed Pro plan.
    ZedPro,
    /// The trial of the Zed Pro plan.
    ZedProTrial,
    /// The Zed Free plan.
    ZedFree,
}
 290
/// JSON request body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the user who is subscribing.
    github_user_id: i32,
    /// The product to start a checkout session for.
    product: ProductCode,
}
 296
/// Response body for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the checkout session the user should be sent to.
    checkout_session_url: String,
}
 301
 302/// Initiates a Stripe Checkout session for creating a billing subscription.
 303async fn create_billing_subscription(
 304    Extension(app): Extension<Arc<AppState>>,
 305    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 306) -> Result<Json<CreateBillingSubscriptionResponse>> {
 307    let user = app
 308        .db
 309        .get_user_by_github_user_id(body.github_user_id)
 310        .await?
 311        .ok_or_else(|| anyhow!("user not found"))?;
 312
 313    let Some(stripe_client) = app.stripe_client.clone() else {
 314        log::error!("failed to retrieve Stripe client");
 315        Err(Error::http(
 316            StatusCode::NOT_IMPLEMENTED,
 317            "not supported".into(),
 318        ))?
 319    };
 320    let Some(stripe_billing) = app.stripe_billing.clone() else {
 321        log::error!("failed to retrieve Stripe billing object");
 322        Err(Error::http(
 323            StatusCode::NOT_IMPLEMENTED,
 324            "not supported".into(),
 325        ))?
 326    };
 327
 328    if app.db.has_active_billing_subscription(user.id).await? {
 329        return Err(Error::http(
 330            StatusCode::CONFLICT,
 331            "user already has an active subscription".into(),
 332        ));
 333    }
 334
 335    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 336    if let Some(existing_billing_customer) = &existing_billing_customer {
 337        if existing_billing_customer.has_overdue_invoices {
 338            return Err(Error::http(
 339                StatusCode::PAYMENT_REQUIRED,
 340                "user has overdue invoices".into(),
 341            ));
 342        }
 343    }
 344
 345    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 346        CustomerId::from_str(&existing_customer.stripe_customer_id)
 347            .context("failed to parse customer ID")?
 348    } else {
 349        let existing_customer = if let Some(email) = user.email_address.as_deref() {
 350            let customers = Customer::list(
 351                &stripe_client,
 352                &stripe::ListCustomers {
 353                    email: Some(email),
 354                    ..Default::default()
 355                },
 356            )
 357            .await?;
 358
 359            customers.data.first().cloned()
 360        } else {
 361            None
 362        };
 363
 364        if let Some(existing_customer) = existing_customer {
 365            existing_customer.id
 366        } else {
 367            let customer = Customer::create(
 368                &stripe_client,
 369                CreateCustomer {
 370                    email: user.email_address.as_deref(),
 371                    ..Default::default()
 372                },
 373            )
 374            .await?;
 375
 376            customer.id
 377        }
 378    };
 379
 380    let success_url = format!(
 381        "{}/account?checkout_complete=1",
 382        app.config.zed_dot_dev_url()
 383    );
 384
 385    let checkout_session_url = match body.product {
 386        ProductCode::ZedPro => {
 387            stripe_billing
 388                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
 389                .await?
 390        }
 391        ProductCode::ZedProTrial => {
 392            if let Some(existing_billing_customer) = &existing_billing_customer {
 393                if existing_billing_customer.trial_started_at.is_some() {
 394                    return Err(Error::http(
 395                        StatusCode::FORBIDDEN,
 396                        "user already used free trial".into(),
 397                    ));
 398                }
 399            }
 400
 401            let feature_flags = app.db.get_user_flags(user.id).await?;
 402
 403            stripe_billing
 404                .checkout_with_zed_pro_trial(
 405                    customer_id,
 406                    &user.github_login,
 407                    feature_flags,
 408                    &success_url,
 409                )
 410                .await?
 411        }
 412        ProductCode::ZedFree => {
 413            stripe_billing
 414                .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
 415                .await?
 416        }
 417    };
 418
 419    Ok(Json(CreateBillingSubscriptionResponse {
 420        checkout_session_url,
 421    }))
 422}
 423
/// What the user wants to do with an existing subscription.
///
/// Deserialized from snake_case strings (e.g. `update_payment_method`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}
 440
/// JSON request body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user managing the subscription.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Optional path, appended to the site URL, to send the user to after updating
    /// their payment method. `/account` is used when absent.
    redirect_to: Option<String>,
}
 449
/// Response body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session, or `None` when the request was
    /// fulfilled directly without opening a portal session.
    billing_portal_session_url: Option<String>,
}
 454
 455/// Initiates a Stripe customer portal session for managing a billing subscription.
 456async fn manage_billing_subscription(
 457    Extension(app): Extension<Arc<AppState>>,
 458    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 459) -> Result<Json<ManageBillingSubscriptionResponse>> {
 460    let user = app
 461        .db
 462        .get_user_by_github_user_id(body.github_user_id)
 463        .await?
 464        .ok_or_else(|| anyhow!("user not found"))?;
 465
 466    let Some(stripe_client) = app.stripe_client.clone() else {
 467        log::error!("failed to retrieve Stripe client");
 468        Err(Error::http(
 469            StatusCode::NOT_IMPLEMENTED,
 470            "not supported".into(),
 471        ))?
 472    };
 473
 474    let Some(stripe_billing) = app.stripe_billing.clone() else {
 475        log::error!("failed to retrieve Stripe billing object");
 476        Err(Error::http(
 477            StatusCode::NOT_IMPLEMENTED,
 478            "not supported".into(),
 479        ))?
 480    };
 481
 482    let customer = app
 483        .db
 484        .get_billing_customer_by_user_id(user.id)
 485        .await?
 486        .ok_or_else(|| anyhow!("billing customer not found"))?;
 487    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 488        .context("failed to parse customer ID")?;
 489
 490    let subscription = app
 491        .db
 492        .get_billing_subscription_by_id(body.subscription_id)
 493        .await?
 494        .ok_or_else(|| anyhow!("subscription not found"))?;
 495    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 496        .context("failed to parse subscription ID")?;
 497
 498    if body.intent == ManageSubscriptionIntent::StopCancellation {
 499        let updated_stripe_subscription = Subscription::update(
 500            &stripe_client,
 501            &subscription_id,
 502            stripe::UpdateSubscription {
 503                cancel_at_period_end: Some(false),
 504                ..Default::default()
 505            },
 506        )
 507        .await?;
 508
 509        app.db
 510            .update_billing_subscription(
 511                subscription.id,
 512                &UpdateBillingSubscriptionParams {
 513                    stripe_cancel_at: ActiveValue::set(
 514                        updated_stripe_subscription
 515                            .cancel_at
 516                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 517                            .map(|time| time.naive_utc()),
 518                    ),
 519                    ..Default::default()
 520                },
 521            )
 522            .await?;
 523
 524        return Ok(Json(ManageBillingSubscriptionResponse {
 525            billing_portal_session_url: None,
 526        }));
 527    }
 528
 529    let flow = match body.intent {
 530        ManageSubscriptionIntent::ManageSubscription => None,
 531        ManageSubscriptionIntent::UpgradeToPro => {
 532            let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
 533            let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
 534
 535            let stripe_subscription =
 536                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 537
 538            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 539                && stripe_subscription.items.data.iter().any(|item| {
 540                    item.price
 541                        .as_ref()
 542                        .map_or(false, |price| price.id == zed_pro_price_id)
 543                });
 544            if is_on_zed_pro_trial {
 545                let payment_methods = PaymentMethod::list(
 546                    &stripe_client,
 547                    &stripe::ListPaymentMethods {
 548                        customer: Some(stripe_subscription.customer.id()),
 549                        ..Default::default()
 550                    },
 551                )
 552                .await?;
 553
 554                let has_payment_method = !payment_methods.data.is_empty();
 555                if !has_payment_method {
 556                    return Err(Error::http(
 557                        StatusCode::BAD_REQUEST,
 558                        "missing payment method".into(),
 559                    ));
 560                }
 561
 562                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 563                Subscription::update(
 564                    &stripe_client,
 565                    &stripe_subscription.id,
 566                    stripe::UpdateSubscription {
 567                        trial_end: Some(stripe::Scheduled::now()),
 568                        ..Default::default()
 569                    },
 570                )
 571                .await?;
 572
 573                return Ok(Json(ManageBillingSubscriptionResponse {
 574                    billing_portal_session_url: None,
 575                }));
 576            }
 577
 578            let subscription_item_to_update = stripe_subscription
 579                .items
 580                .data
 581                .iter()
 582                .find_map(|item| {
 583                    let price = item.price.as_ref()?;
 584
 585                    if price.id == zed_free_price_id {
 586                        Some(item.id.clone())
 587                    } else {
 588                        None
 589                    }
 590                })
 591                .ok_or_else(|| anyhow!("No subscription item to update"))?;
 592
 593            Some(CreateBillingPortalSessionFlowData {
 594                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 595                subscription_update_confirm: Some(
 596                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 597                        subscription: subscription.stripe_subscription_id,
 598                        items: vec![
 599                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 600                                id: subscription_item_to_update.to_string(),
 601                                price: Some(zed_pro_price_id.to_string()),
 602                                quantity: Some(1),
 603                            },
 604                        ],
 605                        discounts: None,
 606                    },
 607                ),
 608                ..Default::default()
 609            })
 610        }
 611        ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
 612            type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
 613            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 614                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 615                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 616                    return_url: format!(
 617                        "{}{path}",
 618                        app.config.zed_dot_dev_url(),
 619                        path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
 620                    ),
 621                }),
 622                ..Default::default()
 623            }),
 624            ..Default::default()
 625        }),
 626        ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
 627            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 628            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 629                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 630                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 631                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 632                }),
 633                ..Default::default()
 634            }),
 635            subscription_cancel: Some(
 636                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 637                    subscription: subscription.stripe_subscription_id,
 638                    retention: None,
 639                },
 640            ),
 641            ..Default::default()
 642        }),
 643        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 644    };
 645
 646    let mut params = CreateBillingPortalSession::new(customer_id);
 647    params.flow_data = flow;
 648    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 649    params.return_url = Some(&return_url);
 650
 651    let session = BillingPortalSession::create(&stripe_client, params).await?;
 652
 653    Ok(Json(ManageBillingSubscriptionResponse {
 654        billing_portal_session_url: Some(session.url),
 655    }))
 656}
 657
/// JSON request body for `POST /billing/subscriptions/migrate`.
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
    /// GitHub user ID used to look up the user being migrated.
    github_user_id: i32,
}
 662
/// Response body for `POST /billing/subscriptions/migrate`.
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
    /// The ID of the subscription that was canceled.
    ///
    /// `None` when the user had no active subscription to cancel.
    canceled_subscription_id: Option<String>,
}
 668
 669async fn migrate_to_new_billing(
 670    Extension(app): Extension<Arc<AppState>>,
 671    extract::Json(body): extract::Json<MigrateToNewBillingBody>,
 672) -> Result<Json<MigrateToNewBillingResponse>> {
 673    let Some(stripe_client) = app.stripe_client.clone() else {
 674        log::error!("failed to retrieve Stripe client");
 675        Err(Error::http(
 676            StatusCode::NOT_IMPLEMENTED,
 677            "not supported".into(),
 678        ))?
 679    };
 680
 681    let user = app
 682        .db
 683        .get_user_by_github_user_id(body.github_user_id)
 684        .await?
 685        .ok_or_else(|| anyhow!("user not found"))?;
 686
 687    let old_billing_subscriptions_by_user = app
 688        .db
 689        .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
 690        .await?;
 691
 692    let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
 693        old_billing_subscriptions_by_user.get(&user.id)
 694    {
 695        let stripe_subscription_id = billing_subscription
 696            .stripe_subscription_id
 697            .parse::<stripe::SubscriptionId>()
 698            .context("failed to parse Stripe subscription ID from database")?;
 699
 700        Subscription::cancel(
 701            &stripe_client,
 702            &stripe_subscription_id,
 703            stripe::CancelSubscription {
 704                invoice_now: Some(true),
 705                ..Default::default()
 706            },
 707        )
 708        .await?;
 709
 710        Some(stripe_subscription_id)
 711    } else {
 712        None
 713    };
 714
 715    let all_feature_flags = app.db.list_feature_flags().await?;
 716    let user_feature_flags = app.db.get_user_flags(user.id).await?;
 717
 718    for feature_flag in ["new-billing", "assistant2"] {
 719        let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
 720        if already_in_feature_flag {
 721            continue;
 722        }
 723
 724        let feature_flag = all_feature_flags
 725            .iter()
 726            .find(|flag| flag.flag == feature_flag)
 727            .context("failed to find feature flag: {feature_flag:?}")?;
 728
 729        app.db.add_user_flag(user.id, feature_flag.id).await?;
 730    }
 731
 732    Ok(Json(MigrateToNewBillingResponse {
 733        canceled_subscription_id: canceled_subscription_id
 734            .map(|subscription_id| subscription_id.to_string()),
 735    }))
 736}
 737
/// JSON request body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID used to look up the user whose subscriptions are synced.
    github_user_id: i32,
}
 742
/// Response body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID stored for the user's billing customer.
    stripe_customer_id: String,
}
 747
 748async fn sync_billing_subscription(
 749    Extension(app): Extension<Arc<AppState>>,
 750    extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
 751) -> Result<Json<SyncBillingSubscriptionResponse>> {
 752    let Some(stripe_client) = app.stripe_client.clone() else {
 753        log::error!("failed to retrieve Stripe client");
 754        Err(Error::http(
 755            StatusCode::NOT_IMPLEMENTED,
 756            "not supported".into(),
 757        ))?
 758    };
 759
 760    let user = app
 761        .db
 762        .get_user_by_github_user_id(body.github_user_id)
 763        .await?
 764        .ok_or_else(|| anyhow!("user not found"))?;
 765
 766    let billing_customer = app
 767        .db
 768        .get_billing_customer_by_user_id(user.id)
 769        .await?
 770        .ok_or_else(|| anyhow!("billing customer not found"))?;
 771    let stripe_customer_id = billing_customer
 772        .stripe_customer_id
 773        .parse::<stripe::CustomerId>()
 774        .context("failed to parse Stripe customer ID from database")?;
 775
 776    let subscriptions = Subscription::list(
 777        &stripe_client,
 778        &stripe::ListSubscriptions {
 779            customer: Some(stripe_customer_id),
 780            // Sync all non-canceled subscriptions.
 781            status: None,
 782            ..Default::default()
 783        },
 784    )
 785    .await?;
 786
 787    for subscription in subscriptions.data {
 788        let subscription_id = subscription.id.clone();
 789
 790        sync_subscription(&app, &stripe_client, subscription)
 791            .await
 792            .with_context(|| {
 793                format!(
 794                    "failed to sync subscription {subscription_id} for user {}",
 795                    user.id,
 796                )
 797            })?;
 798    }
 799
 800    Ok(Json(SyncBillingSubscriptionResponse {
 801        stripe_customer_id: billing_customer.stripe_customer_id.clone(),
 802    }))
 803}
 804
/// The amount of time we wait in between each poll of Stripe events.
///
/// [`poll_stripe_events_periodically`] sleeps for this long after every poll,
/// whether or not that poll succeeded.
///
/// This value should strike a balance between:
///   1. Being short enough that we update quickly when something in Stripe changes
///   2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 817
/// The maximum number of events to return per page.
///
/// This is passed as the `limit` of the `ListEvents` request made by
/// `poll_stripe_events`.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 824
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
///
/// The count is cumulative within a single poll: the fully-processed pages do
/// not have to be consecutive for the threshold to be reached.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 831
 832/// Polls the Stripe events API periodically to reconcile the records in our
 833/// database with the data in Stripe.
 834pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 835    let Some(stripe_client) = app.stripe_client.clone() else {
 836        log::warn!("failed to retrieve Stripe client");
 837        return;
 838    };
 839
 840    let executor = app.executor.clone();
 841    executor.spawn_detached({
 842        let executor = executor.clone();
 843        async move {
 844            loop {
 845                poll_stripe_events(&app, &rpc_server, &stripe_client)
 846                    .await
 847                    .log_err();
 848
 849                executor.sleep(POLL_EVENTS_INTERVAL).await;
 850            }
 851        }
 852    });
 853}
 854
 855async fn poll_stripe_events(
 856    app: &Arc<AppState>,
 857    rpc_server: &Arc<Server>,
 858    stripe_client: &stripe::Client,
 859) -> anyhow::Result<()> {
 860    fn event_type_to_string(event_type: EventType) -> String {
 861        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 862        // so we need to unquote it.
 863        event_type.to_string().trim_matches('"').to_string()
 864    }
 865
 866    let event_types = [
 867        EventType::CustomerCreated,
 868        EventType::CustomerUpdated,
 869        EventType::CustomerSubscriptionCreated,
 870        EventType::CustomerSubscriptionUpdated,
 871        EventType::CustomerSubscriptionPaused,
 872        EventType::CustomerSubscriptionResumed,
 873        EventType::CustomerSubscriptionDeleted,
 874    ]
 875    .into_iter()
 876    .map(event_type_to_string)
 877    .collect::<Vec<_>>();
 878
 879    let mut pages_of_already_processed_events = 0;
 880    let mut unprocessed_events = Vec::new();
 881
 882    log::info!(
 883        "Stripe events: starting retrieval for {}",
 884        event_types.join(", ")
 885    );
 886    let mut params = ListEvents::new();
 887    params.types = Some(event_types.clone());
 888    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 889
 890    let mut event_pages = stripe::Event::list(&stripe_client, &params)
 891        .await?
 892        .paginate(params);
 893
 894    loop {
 895        let processed_event_ids = {
 896            let event_ids = event_pages
 897                .page
 898                .data
 899                .iter()
 900                .map(|event| event.id.as_str())
 901                .collect::<Vec<_>>();
 902            app.db
 903                .get_processed_stripe_events_by_event_ids(&event_ids)
 904                .await?
 905                .into_iter()
 906                .map(|event| event.stripe_event_id)
 907                .collect::<Vec<_>>()
 908        };
 909
 910        let mut processed_events_in_page = 0;
 911        let events_in_page = event_pages.page.data.len();
 912        for event in &event_pages.page.data {
 913            if processed_event_ids.contains(&event.id.to_string()) {
 914                processed_events_in_page += 1;
 915                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 916            } else {
 917                unprocessed_events.push(event.clone());
 918            }
 919        }
 920
 921        if processed_events_in_page == events_in_page {
 922            pages_of_already_processed_events += 1;
 923        }
 924
 925        if event_pages.page.has_more {
 926            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 927            {
 928                log::info!(
 929                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 930                );
 931                break;
 932            } else {
 933                log::info!("Stripe events: retrieving next page");
 934                event_pages = event_pages.next(&stripe_client).await?;
 935            }
 936        } else {
 937            break;
 938        }
 939    }
 940
 941    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 942
 943    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 944    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 945
 946    for event in unprocessed_events {
 947        let event_id = event.id.clone();
 948        let processed_event_params = CreateProcessedStripeEventParams {
 949            stripe_event_id: event.id.to_string(),
 950            stripe_event_type: event_type_to_string(event.type_),
 951            stripe_event_created_timestamp: event.created,
 952        };
 953
 954        // If the event has happened too far in the past, we don't want to
 955        // process it and risk overwriting other more-recent updates.
 956        //
 957        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 958        let one_day = Duration::from_secs(24 * 60 * 60);
 959        let a_day_ago = Utc::now() - one_day;
 960        if a_day_ago.timestamp() > event.created {
 961            log::info!(
 962                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 963                event_id
 964            );
 965            app.db
 966                .create_processed_stripe_event(&processed_event_params)
 967                .await?;
 968
 969            return Ok(());
 970        }
 971
 972        let process_result = match event.type_ {
 973            EventType::CustomerCreated | EventType::CustomerUpdated => {
 974                handle_customer_event(app, stripe_client, event).await
 975            }
 976            EventType::CustomerSubscriptionCreated
 977            | EventType::CustomerSubscriptionUpdated
 978            | EventType::CustomerSubscriptionPaused
 979            | EventType::CustomerSubscriptionResumed
 980            | EventType::CustomerSubscriptionDeleted => {
 981                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 982            }
 983            _ => Ok(()),
 984        };
 985
 986        if let Some(()) = process_result
 987            .with_context(|| format!("failed to process event {event_id} successfully"))
 988            .log_err()
 989        {
 990            app.db
 991                .create_processed_stripe_event(&processed_event_params)
 992                .await?;
 993        }
 994    }
 995
 996    Ok(())
 997}
 998
 999async fn handle_customer_event(
1000    app: &Arc<AppState>,
1001    _stripe_client: &stripe::Client,
1002    event: stripe::Event,
1003) -> anyhow::Result<()> {
1004    let EventObject::Customer(customer) = event.data.object else {
1005        bail!("unexpected event payload for {}", event.id);
1006    };
1007
1008    log::info!("handling Stripe {} event: {}", event.type_, event.id);
1009
1010    let Some(email) = customer.email else {
1011        log::info!("Stripe customer has no email: skipping");
1012        return Ok(());
1013    };
1014
1015    let Some(user) = app.db.get_user_by_email(&email).await? else {
1016        log::info!("no user found for email: skipping");
1017        return Ok(());
1018    };
1019
1020    if let Some(existing_customer) = app
1021        .db
1022        .get_billing_customer_by_stripe_customer_id(&customer.id)
1023        .await?
1024    {
1025        app.db
1026            .update_billing_customer(
1027                existing_customer.id,
1028                &UpdateBillingCustomerParams {
1029                    // For now we just leave the information as-is, as it is not
1030                    // likely to change.
1031                    ..Default::default()
1032                },
1033            )
1034            .await?;
1035    } else {
1036        app.db
1037            .create_billing_customer(&CreateBillingCustomerParams {
1038                user_id: user.id,
1039                stripe_customer_id: customer.id.to_string(),
1040            })
1041            .await?;
1042    }
1043
1044    Ok(())
1045}
1046
1047async fn sync_subscription(
1048    app: &Arc<AppState>,
1049    stripe_client: &stripe::Client,
1050    subscription: stripe::Subscription,
1051) -> anyhow::Result<billing_customer::Model> {
1052    let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
1053        stripe_billing
1054            .determine_subscription_kind(&subscription)
1055            .await
1056    } else {
1057        None
1058    };
1059
1060    let billing_customer =
1061        find_or_create_billing_customer(app, stripe_client, subscription.customer)
1062            .await?
1063            .ok_or_else(|| anyhow!("billing customer not found"))?;
1064
1065    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
1066        if subscription.status == SubscriptionStatus::Trialing {
1067            let current_period_start =
1068                DateTime::from_timestamp(subscription.current_period_start, 0)
1069                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;
1070
1071            app.db
1072                .update_billing_customer(
1073                    billing_customer.id,
1074                    &UpdateBillingCustomerParams {
1075                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
1076                        ..Default::default()
1077                    },
1078                )
1079                .await?;
1080        }
1081    }
1082
1083    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
1084        && subscription
1085            .cancellation_details
1086            .as_ref()
1087            .and_then(|details| details.reason)
1088            .map_or(false, |reason| {
1089                reason == CancellationDetailsReason::PaymentFailed
1090            });
1091
1092    if was_canceled_due_to_payment_failure {
1093        app.db
1094            .update_billing_customer(
1095                billing_customer.id,
1096                &UpdateBillingCustomerParams {
1097                    has_overdue_invoices: ActiveValue::set(true),
1098                    ..Default::default()
1099                },
1100            )
1101            .await?;
1102    }
1103
1104    if let Some(existing_subscription) = app
1105        .db
1106        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
1107        .await?
1108    {
1109        app.db
1110            .update_billing_subscription(
1111                existing_subscription.id,
1112                &UpdateBillingSubscriptionParams {
1113                    billing_customer_id: ActiveValue::set(billing_customer.id),
1114                    kind: ActiveValue::set(subscription_kind),
1115                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
1116                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
1117                    stripe_cancel_at: ActiveValue::set(
1118                        subscription
1119                            .cancel_at
1120                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
1121                            .map(|time| time.naive_utc()),
1122                    ),
1123                    stripe_cancellation_reason: ActiveValue::set(
1124                        subscription
1125                            .cancellation_details
1126                            .and_then(|details| details.reason)
1127                            .map(|reason| reason.into()),
1128                    ),
1129                    stripe_current_period_start: ActiveValue::set(Some(
1130                        subscription.current_period_start,
1131                    )),
1132                    stripe_current_period_end: ActiveValue::set(Some(
1133                        subscription.current_period_end,
1134                    )),
1135                },
1136            )
1137            .await?;
1138    } else {
1139        // If the user already has an active billing subscription, ignore the
1140        // event and return an `Ok` to signal that it was processed
1141        // successfully.
1142        //
1143        // There is the possibility that this could cause us to not create a
1144        // subscription in the following scenario:
1145        //
1146        //   1. User has an active subscription A
1147        //   2. User cancels subscription A
1148        //   3. User creates a new subscription B
1149        //   4. We process the new subscription B before the cancellation of subscription A
1150        //   5. User ends up with no subscriptions
1151        //
1152        // In theory this situation shouldn't arise as we try to process the events in the order they occur.
1153        if app
1154            .db
1155            .has_active_billing_subscription(billing_customer.user_id)
1156            .await?
1157        {
1158            log::info!(
1159                "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
1160                user_id = billing_customer.user_id,
1161                subscription_id = subscription.id
1162            );
1163            return Ok(billing_customer);
1164        }
1165
1166        app.db
1167            .create_billing_subscription(&CreateBillingSubscriptionParams {
1168                billing_customer_id: billing_customer.id,
1169                kind: subscription_kind,
1170                stripe_subscription_id: subscription.id.to_string(),
1171                stripe_subscription_status: subscription.status.into(),
1172                stripe_cancellation_reason: subscription
1173                    .cancellation_details
1174                    .and_then(|details| details.reason)
1175                    .map(|reason| reason.into()),
1176                stripe_current_period_start: Some(subscription.current_period_start),
1177                stripe_current_period_end: Some(subscription.current_period_end),
1178            })
1179            .await?;
1180    }
1181
1182    if let Some(stripe_billing) = app.stripe_billing.as_ref() {
1183        if subscription.status == SubscriptionStatus::Canceled
1184            || subscription.status == SubscriptionStatus::Paused
1185        {
1186            let stripe_customer_id = billing_customer
1187                .stripe_customer_id
1188                .parse::<stripe::CustomerId>()
1189                .context("failed to parse Stripe customer ID from database")?;
1190
1191            stripe_billing
1192                .subscribe_to_zed_free(stripe_customer_id)
1193                .await?;
1194        }
1195    }
1196
1197    Ok(billing_customer)
1198}
1199
1200async fn handle_customer_subscription_event(
1201    app: &Arc<AppState>,
1202    rpc_server: &Arc<Server>,
1203    stripe_client: &stripe::Client,
1204    event: stripe::Event,
1205) -> anyhow::Result<()> {
1206    let EventObject::Subscription(subscription) = event.data.object else {
1207        bail!("unexpected event payload for {}", event.id);
1208    };
1209
1210    log::info!("handling Stripe {} event: {}", event.type_, event.id);
1211
1212    let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1213
1214    // When the user's subscription changes, push down any changes to their plan.
1215    rpc_server
1216        .update_plan_for_user(billing_customer.user_id)
1217        .await
1218        .trace_err();
1219
1220    // When the user's subscription changes, we want to refresh their LLM tokens
1221    // to either grant/revoke access.
1222    rpc_server
1223        .refresh_llm_tokens_for_user(billing_customer.user_id)
1224        .await;
1225
1226    Ok(())
1227}
1228
/// Query parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID used to look up the user whose spend is reported.
    github_user_id: i32,
}
1233
/// Response body returned by `get_monthly_spend`. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// The part of this month's spending that falls within the free-tier
    /// allowance (the smaller of the month's spending and the allowance).
    monthly_free_tier_spend_in_cents: u32,
    /// The free-tier allowance that applies to the user: their custom monthly
    /// allowance when set, otherwise `FREE_TIER_MONTHLY_SPENDING_LIMIT`.
    monthly_free_tier_allowance_in_cents: u32,
    /// This month's spending beyond the free-tier allowance (zero when the
    /// spending does not exceed the allowance).
    monthly_spend_in_cents: u32,
}
1240
1241async fn get_monthly_spend(
1242    Extension(app): Extension<Arc<AppState>>,
1243    Query(params): Query<GetMonthlySpendParams>,
1244) -> Result<Json<GetMonthlySpendResponse>> {
1245    let user = app
1246        .db
1247        .get_user_by_github_user_id(params.github_user_id)
1248        .await?
1249        .ok_or_else(|| anyhow!("user not found"))?;
1250
1251    let Some(llm_db) = app.llm_db.clone() else {
1252        return Err(Error::http(
1253            StatusCode::NOT_IMPLEMENTED,
1254            "LLM database not available".into(),
1255        ));
1256    };
1257
1258    let free_tier = user
1259        .custom_llm_monthly_allowance_in_cents
1260        .map(|allowance| Cents(allowance as u32))
1261        .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1262
1263    let spending_for_month = llm_db
1264        .get_user_spending_for_month(user.id, Utc::now())
1265        .await?;
1266
1267    let free_tier_spend = Cents::min(spending_for_month, free_tier);
1268    let monthly_spend = spending_for_month.saturating_sub(free_tier);
1269
1270    Ok(Json(GetMonthlySpendResponse {
1271        monthly_free_tier_spend_in_cents: free_tier_spend.0,
1272        monthly_free_tier_allowance_in_cents: free_tier.0,
1273        monthly_spend_in_cents: monthly_spend.0,
1274    }))
1275}
1276
/// Query parameters accepted by `get_current_usage`.
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID used to look up the user whose usage is reported.
    github_user_id: i32,
}
1281
/// Usage of a single metered resource, as reported by `get_current_usage`.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// How much has been used so far.
    pub used: i32,
    /// The plan's limit, or `None` when the plan's limit is unlimited.
    pub limit: Option<i32>,
    /// How much of the limit is left (never negative), or `None` when the
    /// plan's limit is unlimited.
    pub remaining: Option<i32>,
}
1288
/// Request count for one usage meter, as reported by `get_current_usage`.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model the usage meter refers to.
    pub model: String,
    /// Completion mode recorded on the usage meter.
    pub mode: CompletionMode,
    /// Number of requests recorded on the usage meter.
    pub requests: i32,
}
1295
/// Usage figures included in the `get_current_usage` response.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Overall model request usage against the plan's model request limit.
    pub model_requests: UsageCounts,
    /// Per-usage-meter request counts; empty when no usage has been recorded
    /// for the period.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit prediction usage against the plan's edit prediction limit.
    pub edit_predictions: UsageCounts,
}
1302
/// Response body returned by `get_current_usage`.
///
/// The `Default` value (empty plan, no usage) is what the handler returns when
/// the user has no active billing subscription or the subscription has no
/// current period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// The plan name, taken from `Plan::as_str`.
    pub plan: String,
    /// Usage for the current subscription period, when one could be determined.
    pub current_usage: Option<CurrentUsage>,
}
1308
1309async fn get_current_usage(
1310    Extension(app): Extension<Arc<AppState>>,
1311    Query(params): Query<GetCurrentUsageParams>,
1312) -> Result<Json<GetCurrentUsageResponse>> {
1313    let user = app
1314        .db
1315        .get_user_by_github_user_id(params.github_user_id)
1316        .await?
1317        .ok_or_else(|| anyhow!("user not found"))?;
1318
1319    let feature_flags = app.db.get_user_flags(user.id).await?;
1320    let has_extended_trial = feature_flags
1321        .iter()
1322        .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1323
1324    let Some(llm_db) = app.llm_db.clone() else {
1325        return Err(Error::http(
1326            StatusCode::NOT_IMPLEMENTED,
1327            "LLM database not available".into(),
1328        ));
1329    };
1330
1331    let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1332        return Ok(Json(GetCurrentUsageResponse::default()));
1333    };
1334
1335    let subscription_period = maybe!({
1336        let period_start_at = subscription.current_period_start_at()?;
1337        let period_end_at = subscription.current_period_end_at()?;
1338
1339        Some((period_start_at, period_end_at))
1340    });
1341
1342    let Some((period_start_at, period_end_at)) = subscription_period else {
1343        return Ok(Json(GetCurrentUsageResponse::default()));
1344    };
1345
1346    let usage = llm_db
1347        .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1348        .await?;
1349
1350    let plan = usage
1351        .as_ref()
1352        .map(|usage| usage.plan.into())
1353        .unwrap_or_else(|| {
1354            subscription
1355                .kind
1356                .map(Into::into)
1357                .unwrap_or(zed_llm_client::Plan::ZedFree)
1358        });
1359
1360    let model_requests_limit = match plan.model_requests_limit() {
1361        zed_llm_client::UsageLimit::Limited(limit) => {
1362            let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1363                1_000
1364            } else {
1365                limit
1366            };
1367
1368            Some(limit)
1369        }
1370        zed_llm_client::UsageLimit::Unlimited => None,
1371    };
1372
1373    let edit_predictions_limit = match plan.edit_predictions_limit() {
1374        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1375        zed_llm_client::UsageLimit::Unlimited => None,
1376    };
1377
1378    let Some(usage) = usage else {
1379        return Ok(Json(GetCurrentUsageResponse {
1380            plan: plan.as_str().to_string(),
1381            current_usage: Some(CurrentUsage {
1382                model_requests: UsageCounts {
1383                    used: 0,
1384                    limit: model_requests_limit,
1385                    remaining: model_requests_limit,
1386                },
1387                model_request_usage: Vec::new(),
1388                edit_predictions: UsageCounts {
1389                    used: 0,
1390                    limit: edit_predictions_limit,
1391                    remaining: edit_predictions_limit,
1392                },
1393            }),
1394        }));
1395    };
1396
1397    let subscription_usage_meters = llm_db
1398        .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1399        .await?;
1400
1401    let model_request_usage = subscription_usage_meters
1402        .into_iter()
1403        .filter_map(|(usage_meter, _usage)| {
1404            let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1405
1406            Some(ModelRequestUsage {
1407                model: model.name.clone(),
1408                mode: usage_meter.mode,
1409                requests: usage_meter.requests,
1410            })
1411        })
1412        .collect::<Vec<_>>();
1413
1414    Ok(Json(GetCurrentUsageResponse {
1415        plan: plan.as_str().to_string(),
1416        current_usage: Some(CurrentUsage {
1417            model_requests: UsageCounts {
1418                used: usage.model_requests,
1419                limit: model_requests_limit,
1420                remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1421            },
1422            model_request_usage,
1423            edit_predictions: UsageCounts {
1424                used: usage.edit_predictions,
1425                limit: edit_predictions_limit,
1426                remaining: edit_predictions_limit
1427                    .map(|limit| (limit - usage.edit_predictions).max(0)),
1428            },
1429        }),
1430    }))
1431}
1432
1433impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1434    fn from(value: SubscriptionStatus) -> Self {
1435        match value {
1436            SubscriptionStatus::Incomplete => Self::Incomplete,
1437            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1438            SubscriptionStatus::Trialing => Self::Trialing,
1439            SubscriptionStatus::Active => Self::Active,
1440            SubscriptionStatus::PastDue => Self::PastDue,
1441            SubscriptionStatus::Canceled => Self::Canceled,
1442            SubscriptionStatus::Unpaid => Self::Unpaid,
1443            SubscriptionStatus::Paused => Self::Paused,
1444        }
1445    }
1446}
1447
1448impl From<CancellationDetailsReason> for StripeCancellationReason {
1449    fn from(value: CancellationDetailsReason) -> Self {
1450        match value {
1451            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1452            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1453            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1454        }
1455    }
1456}
1457
1458/// Finds or creates a billing customer using the provided customer.
1459async fn find_or_create_billing_customer(
1460    app: &Arc<AppState>,
1461    stripe_client: &stripe::Client,
1462    customer_or_id: Expandable<Customer>,
1463) -> anyhow::Result<Option<billing_customer::Model>> {
1464    let customer_id = match &customer_or_id {
1465        Expandable::Id(id) => id,
1466        Expandable::Object(customer) => customer.id.as_ref(),
1467    };
1468
1469    // If we already have a billing customer record associated with the Stripe customer,
1470    // there's nothing more we need to do.
1471    if let Some(billing_customer) = app
1472        .db
1473        .get_billing_customer_by_stripe_customer_id(customer_id)
1474        .await?
1475    {
1476        return Ok(Some(billing_customer));
1477    }
1478
1479    // If all we have is a customer ID, resolve it to a full customer record by
1480    // hitting the Stripe API.
1481    let customer = match customer_or_id {
1482        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1483        Expandable::Object(customer) => *customer,
1484    };
1485
1486    let Some(email) = customer.email else {
1487        return Ok(None);
1488    };
1489
1490    let Some(user) = app.db.get_user_by_email(&email).await? else {
1491        return Ok(None);
1492    };
1493
1494    let billing_customer = app
1495        .db
1496        .create_billing_customer(&CreateBillingCustomerParams {
1497            user_id: user.id,
1498            stripe_customer_id: customer.id.to_string(),
1499        })
1500        .await?;
1501
1502    Ok(Some(billing_customer))
1503}
1504
/// The amount of time `sync_llm_request_usage_with_stripe_periodically` sleeps
/// after each attempt to sync LLM request usage to Stripe.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1506
1507pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1508    let Some(stripe_billing) = app.stripe_billing.clone() else {
1509        log::warn!("failed to retrieve Stripe billing object");
1510        return;
1511    };
1512    let Some(llm_db) = app.llm_db.clone() else {
1513        log::warn!("failed to retrieve LLM database");
1514        return;
1515    };
1516
1517    let executor = app.executor.clone();
1518    executor.spawn_detached({
1519        let executor = executor.clone();
1520        async move {
1521            loop {
1522                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1523                    .await
1524                    .context("failed to sync LLM request usage to Stripe")
1525                    .trace_err();
1526                executor
1527                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1528                    .await;
1529            }
1530        }
1531    });
1532}
1533
1534async fn sync_model_request_usage_with_stripe(
1535    app: &Arc<AppState>,
1536    llm_db: &Arc<LlmDatabase>,
1537    stripe_billing: &Arc<StripeBilling>,
1538) -> anyhow::Result<()> {
1539    let staff_users = app.db.get_staff_users().await?;
1540    let staff_user_ids = staff_users
1541        .iter()
1542        .map(|user| user.id)
1543        .collect::<HashSet<UserId>>();
1544
1545    let usage_meters = llm_db
1546        .get_current_subscription_usage_meters(Utc::now())
1547        .await?;
1548    let usage_meters = usage_meters
1549        .into_iter()
1550        .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
1551        .collect::<Vec<_>>();
1552    let user_ids = usage_meters
1553        .iter()
1554        .map(|(_, usage)| usage.user_id)
1555        .collect::<HashSet<UserId>>();
1556    let billing_subscriptions = app
1557        .db
1558        .get_active_zed_pro_billing_subscriptions(user_ids)
1559        .await?;
1560
1561    let claude_3_5_sonnet = stripe_billing
1562        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1563        .await?;
1564    let claude_3_7_sonnet = stripe_billing
1565        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1566        .await?;
1567    let claude_3_7_sonnet_max = stripe_billing
1568        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1569        .await?;
1570
1571    for (usage_meter, usage) in usage_meters {
1572        maybe!(async {
1573            let Some((billing_customer, billing_subscription)) =
1574                billing_subscriptions.get(&usage.user_id)
1575            else {
1576                bail!(
1577                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
1578                    usage.user_id
1579                );
1580            };
1581
1582            let stripe_customer_id = billing_customer
1583                .stripe_customer_id
1584                .parse::<stripe::CustomerId>()
1585                .context("failed to parse Stripe customer ID from database")?;
1586            let stripe_subscription_id = billing_subscription
1587                .stripe_subscription_id
1588                .parse::<stripe::SubscriptionId>()
1589                .context("failed to parse Stripe subscription ID from database")?;
1590
1591            let model = llm_db.model_by_id(usage_meter.model_id)?;
1592
1593            let (price, meter_event_name) = match model.name.as_str() {
1594                "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1595                "claude-3-7-sonnet" => match usage_meter.mode {
1596                    CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
1597                    CompletionMode::Max => {
1598                        (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1599                    }
1600                },
1601                model_name => {
1602                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1603                }
1604            };
1605
1606            stripe_billing
1607                .subscribe_to_price(&stripe_subscription_id, price)
1608                .await?;
1609            stripe_billing
1610                .bill_model_request_usage(
1611                    &stripe_customer_id,
1612                    meter_event_name,
1613                    usage_meter.requests,
1614                )
1615                .await?;
1616
1617            Ok(())
1618        })
1619        .await
1620        .log_err();
1621    }
1622
1623    Ok(())
1624}